[SPARK-29754][ML] LoR/AFT/LiR/SVC use Summarizer instead of MultivariateOnlineSummarizer

### What changes were proposed in this pull request?
1. Change the scope of `ml.SummarizerBuffer` and add a `createSummarizerBuffer` method for it, so that it can be used as an aggregator like `MultivariateOnlineSummarizer` (see the sketch after the `Summarizer.scala` diff below);
2. In LoR/AFT/LiR/SVC, use `Summarizer` instead of `MultivariateOnlineSummarizer`.

### Why are the changes needed?
The computation of the summary before the learning iterations is a bottleneck in high-dimensional cases, since `MultivariateOnlineSummarizer` computes much more than is needed: it maintains buffers for every supported metric, while the trainers only need mean/variance/count. The [ticket](https://issues.apache.org/jira/browse/SPARK-29754) contains an example: with `--driver-memory=4G`, LoR always fails on the KDDA dataset, but if we switch to `ml.Summarizer`, `--driver-memory=3G` is enough to train a model.
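
As an illustration of the difference, the public `ml.stat.Summarizer` API that `SummarizerBuffer` backs lets callers name exactly the metrics they need, and only those are computed and buffered. A minimal sketch with a made-up two-row dataset (the real case in the ticket is the high-dimensional KDDA data):

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.sql.{Row, SparkSession}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Toy data standing in for the KDDA dataset.
val df = Seq(
  (Vectors.dense(1.0, 2.0), 1.0),
  (Vectors.dense(3.0, 4.0), 2.0)
).toDF("features", "weight")

// Request only mean/variance/count; nothing else is tracked.
val Row(mean, variance, count) = df
  .select(Summarizer.metrics("mean", "variance", "count")
    .summary($"features", $"weight").as("summary"))
  .select("summary.mean", "summary.variance", "summary.count")
  .first()
```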

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
Existing test suites and a manual test in the REPL.

Closes #26396 from zhengruifeng/using_SummarizerBuffer.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: zhengruifeng <ruifengz@foxmail.com>
zhengruifeng 2019-11-06 18:19:39 +08:00
parent 0dcd739534
commit 5853e8b330
5 changed files with 45 additions and 56 deletions

#### mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala

```diff
@@ -32,10 +32,9 @@ import org.apache.spark.ml.optim.aggregator.HingeAggregator
 import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.stat.SummaryBuilderImpl._
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
-import org.apache.spark.mllib.linalg.VectorImplicits._
-import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.sql.{Dataset, Row}
 import org.apache.spark.storage.StorageLevel
@@ -170,19 +169,15 @@ class LinearSVC @Since("2.2.0") (
     instr.logParams(this, labelCol, weightCol, featuresCol, predictionCol, rawPredictionCol,
       regParam, maxIter, fitIntercept, tol, standardization, threshold, aggregationDepth)
-    val (summarizer, labelSummarizer) = {
-      val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
-        instance: Instance) =>
-          (c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight))
-
-      val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer),
-        c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) =>
-          (c1._1.merge(c2._1), c1._2.merge(c2._2))
-
-      instances.treeAggregate(
-        (new MultivariateOnlineSummarizer, new MultiClassSummarizer)
-      )(seqOp, combOp, $(aggregationDepth))
-    }
+    val (summarizer, labelSummarizer) = instances.treeAggregate(
+      (createSummarizerBuffer("mean", "variance", "count"), new MultiClassSummarizer))(
+      seqOp = (c: (SummarizerBuffer, MultiClassSummarizer), instance: Instance) =>
+        (c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight)),
+      combOp = (c1: (SummarizerBuffer, MultiClassSummarizer),
+        c2: (SummarizerBuffer, MultiClassSummarizer)) =>
+        (c1._1.merge(c2._1), c1._2.merge(c2._2)),
+      depth = $(aggregationDepth)
+    )
 
     instr.logNumExamples(summarizer.count)
     instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString)
     instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString)
```

#### mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

```diff
@@ -34,11 +34,10 @@ import org.apache.spark.ml.optim.aggregator.LogisticAggregator
 import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.stat.SummaryBuilderImpl._
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
-import org.apache.spark.mllib.linalg.VectorImplicits._
-import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.col
@@ -501,19 +500,16 @@ class LogisticRegression @Since("1.2.0") (
       probabilityCol, regParam, elasticNetParam, standardization, threshold, maxIter, tol,
       fitIntercept)
 
-    val (summarizer, labelSummarizer) = {
-      val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
-        instance: Instance) =>
-          (c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight))
-      val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer),
-        c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) =>
-          (c1._1.merge(c2._1), c1._2.merge(c2._2))
-
-      instances.treeAggregate(
-        (new MultivariateOnlineSummarizer, new MultiClassSummarizer)
-      )(seqOp, combOp, $(aggregationDepth))
-    }
+    val (summarizer, labelSummarizer) = instances.treeAggregate(
+      (createSummarizerBuffer("mean", "variance", "count"), new MultiClassSummarizer))(
+      seqOp = (c: (SummarizerBuffer, MultiClassSummarizer), instance: Instance) =>
+        (c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight)),
+      combOp = (c1: (SummarizerBuffer, MultiClassSummarizer),
+        c2: (SummarizerBuffer, MultiClassSummarizer)) =>
+        (c1._1.merge(c2._1), c1._2.merge(c2._2)),
+      depth = $(aggregationDepth)
+    )
 
     instr.logNumExamples(summarizer.count)
     instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString)
     instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString)
```

#### mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala

```diff
@@ -31,10 +31,9 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.stat.SummaryBuilderImpl._
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
-import org.apache.spark.mllib.linalg.VectorImplicits._
-import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
@@ -215,15 +214,12 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
     val handlePersistence = dataset.storageLevel == StorageLevel.NONE
     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
-    val featuresSummarizer = {
-      val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features)
-      val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => {
-        c1.merge(c2)
-      }
-      instances.treeAggregate(
-        new MultivariateOnlineSummarizer
-      )(seqOp, combOp, $(aggregationDepth))
-    }
+    val featuresSummarizer = instances.treeAggregate(
+      createSummarizerBuffer("mean", "variance", "count"))(
+      seqOp = (c: SummarizerBuffer, v: AFTPoint) => c.add(v.features),
+      combOp = (c1: SummarizerBuffer, c2: SummarizerBuffer) => c1.merge(c2),
+      depth = $(aggregationDepth)
+    )
 
     val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
     val numFeatures = featuresStd.size
```

#### mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

```diff
@@ -36,12 +36,12 @@ import org.apache.spark.ml.optim.aggregator.{HuberAggregator, LeastSquaresAggreg
 import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
 import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.stat.SummaryBuilderImpl._
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel}
-import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions._
@@ -357,20 +357,17 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     val handlePersistence = dataset.storageLevel == StorageLevel.NONE
     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
-    val (featuresSummarizer, ySummarizer) = {
-      val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer),
-        instance: Instance) =>
-          (c._1.add(instance.features, instance.weight),
-            c._2.add(Vectors.dense(instance.label), instance.weight))
-
-      val combOp = (c1: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer),
-        c2: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer)) =>
-          (c1._1.merge(c2._1), c1._2.merge(c2._2))
-
-      instances.treeAggregate(
-        (new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer)
-      )(seqOp, combOp, $(aggregationDepth))
-    }
+    val (featuresSummarizer, ySummarizer) = instances.treeAggregate(
+      (createSummarizerBuffer("mean", "variance"),
+        createSummarizerBuffer("mean", "variance", "count")))(
+      seqOp = (c: (SummarizerBuffer, SummarizerBuffer), instance: Instance) =>
+        (c._1.add(instance.features, instance.weight),
+          c._2.add(Vectors.dense(instance.label), instance.weight)),
+      combOp = (c1: (SummarizerBuffer, SummarizerBuffer),
+        c2: (SummarizerBuffer, SummarizerBuffer)) =>
+        (c1._1.merge(c2._1), c1._2.merge(c2._2)),
+      depth = $(aggregationDepth)
+    )
 
     val yMean = ySummarizer.mean(0)
     val rawYStd = math.sqrt(ySummarizer.variance(0))
```

#### mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala

```diff
@@ -230,6 +230,11 @@ private[ml] object SummaryBuilderImpl extends Logging {
     StructType(fields)
   }
 
+  private[ml] def createSummarizerBuffer(requested: String*): SummarizerBuffer = {
+    val (metrics, computeMetrics) = getRelevantMetrics(requested)
+    new SummarizerBuffer(metrics, computeMetrics)
+  }
+
   private val vectorUDT = new VectorUDT
 
   /**
@@ -277,7 +282,7 @@ private[ml] object SummaryBuilderImpl extends Logging {
   private[stat] case object ComputeMax extends ComputeMetric
   private[stat] case object ComputeMin extends ComputeMetric
 
-private[stat] class SummarizerBuffer(
+private[ml] class SummarizerBuffer(
    requestedMetrics: Seq[Metric],
    requestedCompMetrics: Seq[ComputeMetric]
 ) extends Serializable {
```
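
With the visibility widened from `private[stat]` to `private[ml]`, other ml code can also drive the buffer directly as a local accumulator. A hypothetical snippet (again only valid inside `org.apache.spark.ml`) showing that `add`/`merge` behave like the seq/comb operations above:

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.stat.SummaryBuilderImpl._

// Two partial buffers, as would exist on two partitions...
val partA = createSummarizerBuffer("mean", "variance", "count")
partA.add(Vectors.dense(1.0, 2.0), 1.0)

val partB = createSummarizerBuffer("mean", "variance", "count")
partB.add(Vectors.dense(3.0, 4.0), 2.0)

// ...merged into one, exactly what combOp does in treeAggregate.
val merged = partA.merge(partB)
assert(merged.count == 2)
println(merged.mean)      // weighted mean over both rows
println(merged.variance)
```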