[SPARK-29751][ML] Scalers use Summarizer instead of MultivariateOnlineSummarizer
### What changes were proposed in this pull request?
Use `ml.Summarizer` instead of `mllib.MultivariateOnlineSummarizer`.

### Why are the changes needed?
1. I found that using `ml.Summarizer` is faster than the current implementation.
2. `mllib.MultivariateOnlineSummarizer` maintains all statistics arrays, while `ml.Summarizer` maintains only the arrays that are actually needed.
3. Using `ml.Summarizer` avoids converting vectors to `mllib.Vector`.

### Does this PR introduce any user-facing change?
No.

### How was this patch tested?
Existing test suites.

Closes #26393 from zhengruifeng/maxabs_opt.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
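Not part of the patch: a minimal, self-contained sketch contrasting the two summarization paths the description compares. The SparkSession setup, toy data, and the `features` column name are illustrative assumptions, not code from this PR.

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder().master("local[*]").appName("summarizer-sketch").getOrCreate()
import spark.implicits._

// Toy data; the column name "features" is illustrative.
val df = Seq(Vectors.dense(1.0, -8.0), Vectors.dense(-3.0, 4.0))
  .map(Tuple1.apply).toDF("features")

// New path: a single DataFrame aggregation that tracks only the requested metrics.
val Row(max: Vector, min: Vector) = df
  .select(Summarizer.metrics("max", "min").summary(col("features")).as("summary"))
  .select("summary.max", "summary.min")
  .first()

// Old path: map every row to an mllib.Vector and summarize with Statistics.colStats,
// which is backed by MultivariateOnlineSummarizer and keeps all statistic arrays.
val summary = Statistics.colStats(
  df.select("features").rdd.map { case Row(v: Vector) => OldVectors.fromML(v) })

println((max, min))                  // ml path
println((summary.max, summary.min))  // mllib path, same values as mllib vectors
```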
parent 8353000b47
commit 854f30ffa8
@@ -24,9 +24,8 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param.{ParamMap, Params}
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.stat.Summarizer
 import org.apache.spark.ml.util._
-import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
-import org.apache.spark.mllib.stat.Statistics
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{StructField, StructType}

@@ -69,14 +68,13 @@ class MaxAbsScaler @Since("2.0.0") (@Since("2.0.0") override val uid: String)
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): MaxAbsScalerModel = {
     transformSchema(dataset.schema, logging = true)
-    val input = dataset.select($(inputCol)).rdd.map {
-      case Row(v: Vector) => OldVectors.fromML(v)
-    }
-    val summary = Statistics.colStats(input)
-    val minVals = summary.min.toArray
-    val maxVals = summary.max.toArray
-    val n = minVals.length
-    val maxAbs = Array.tabulate(n) { i => math.max(math.abs(minVals(i)), math.abs(maxVals(i))) }
+
+    val Row(max: Vector, min: Vector) = dataset
+      .select(Summarizer.metrics("max", "min").summary(col($(inputCol))).as("summary"))
+      .select("summary.max", "summary.min")
+      .first()
+
+    val maxAbs = Array.tabulate(max.size) { i => math.max(math.abs(min(i)), math.abs(max(i))) }
 
     copyValues(new MaxAbsScalerModel(uid, Vectors.dense(maxAbs).compressed).setParent(this))
   }

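Not part of the diff: a `MaxAbsScaler` usage sketch on toy data, assuming the SparkSession, implicits, and `toDF` pattern from the earlier sketch. It only illustrates that the new `fit` still produces `maxAbs(i) = max(|min_i|, |max_i|)` per feature.

```scala
// Assumes the `spark` session and `import spark.implicits._` from the sketch above.
import org.apache.spark.ml.feature.MaxAbsScaler
import org.apache.spark.ml.linalg.Vectors

val data = Seq(Vectors.dense(1.0, -8.0), Vectors.dense(-3.0, 4.0))
  .map(Tuple1.apply).toDF("features")

val maxAbsModel = new MaxAbsScaler()
  .setInputCol("features")
  .setOutputCol("scaled")
  .fit(data)

// Per-feature maxAbs(i) = max(|min_i|, |max_i|): here [3.0, 8.0].
println(maxAbsModel.maxAbs)
maxAbsModel.transform(data).show(false)
```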
@@ -24,12 +24,9 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params}
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.stat.Summarizer
 import org.apache.spark.ml.util._
-import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
-import org.apache.spark.mllib.linalg.VectorImplicits._
-import org.apache.spark.mllib.stat.Statistics
 import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{StructField, StructType}

@@ -117,12 +114,13 @@ class MinMaxScaler @Since("1.5.0") (@Since("1.5.0") override val uid: String)
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): MinMaxScalerModel = {
     transformSchema(dataset.schema, logging = true)
-    val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map {
-      case Row(v: Vector) => OldVectors.fromML(v)
-    }
-    val summary = Statistics.colStats(input)
-    copyValues(new MinMaxScalerModel(uid, summary.min.compressed,
-      summary.max.compressed).setParent(this))
+
+    val Row(max: Vector, min: Vector) = dataset
+      .select(Summarizer.metrics("max", "min").summary(col($(inputCol))).as("summary"))
+      .select("summary.max", "summary.min")
+      .first()
+
+    copyValues(new MinMaxScalerModel(uid, min.compressed, max.compressed).setParent(this))
   }
 
   @Since("1.5.0")

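Not part of the diff: a `MinMaxScaler` usage sketch on illustrative data (same SparkSession/implicits assumption as above). `fit` collects per-feature min/max via `Summarizer`; `transform` then rescales each feature as `(e_i - E_min) / (E_max - E_min) * (max - min) + min`, with constant features mapped to `0.5 * (max + min)`.

```scala
// Assumes the same `spark` session / implicits as the first sketch.
import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.ml.linalg.Vectors

val mmData = Seq(Vectors.dense(0.0, 10.0), Vectors.dense(2.0, 30.0))
  .map(Tuple1.apply).toDF("features")

val mmModel = new MinMaxScaler()
  .setInputCol("features")
  .setOutputCol("scaled")
  .fit(mmData)

println(mmModel.originalMin)  // per-feature minima collected in fit(): [0.0, 10.0]
println(mmModel.originalMax)  // per-feature maxima collected in fit(): [2.0, 30.0]
mmModel.transform(mmData).show(false)
```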
@@ -24,10 +24,8 @@ import org.apache.spark.ml._
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.stat.Summarizer
 import org.apache.spark.ml.util._
-import org.apache.spark.mllib.feature.{StandardScaler => OldStandardScaler}
-import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
-import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._

@@ -109,13 +107,15 @@ class StandardScaler @Since("1.4.0") (
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): StandardScalerModel = {
     transformSchema(dataset.schema, logging = true)
-    val input = dataset.select($(inputCol)).rdd.map {
-      case Row(v: Vector) => OldVectors.fromML(v)
-    }
-    val scaler = new OldStandardScaler(withMean = $(withMean), withStd = $(withStd))
-    val scalerModel = scaler.fit(input)
-    copyValues(new StandardScalerModel(uid, scalerModel.std.compressed,
-      scalerModel.mean.compressed).setParent(this))
+
+    val Row(mean: Vector, variance: Vector) = dataset
+      .select(Summarizer.metrics("mean", "variance").summary(col($(inputCol))).as("summary"))
+      .select("summary.mean", "summary.variance")
+      .first()
+
+    val std = Vectors.dense(variance.toArray.map(math.sqrt))
+
+    copyValues(new StandardScalerModel(uid, std.compressed, mean.compressed).setParent(this))
   }
 
   @Since("1.4.0")

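Not part of the diff: a `StandardScaler` usage sketch on illustrative data (same SparkSession/implicits assumption). The new `fit` takes mean/variance from `Summarizer` and derives the per-feature standard deviation as an element-wise square root of the sample variance, which is what the removed `mllib.feature.StandardScaler.fit` path produced before.

```scala
// Assumes the same `spark` session / implicits as the first sketch.
import org.apache.spark.ml.feature.StandardScaler
import org.apache.spark.ml.linalg.Vectors

val ssData = Seq(Vectors.dense(1.0, 10.0), Vectors.dense(3.0, 20.0))
  .map(Tuple1.apply).toDF("features")

val ssModel = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaled")
  .setWithMean(true)
  .setWithStd(true)
  .fit(ssData)

println(ssModel.mean)  // per-feature mean from the Summarizer aggregation: [2.0, 15.0]
println(ssModel.std)   // element-wise sqrt of the sample variance: sqrt([2.0, 50.0])
ssModel.transform(ssData).show(false)
```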
@@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate}
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.types._

@@ -65,7 +65,7 @@ sealed abstract class SummaryBuilder {
  * val dataframe = ... // Some dataframe containing a feature column and a weight column
  * val multiStatsDF = dataframe.select(
  *   Summarizer.metrics("min", "max", "count").summary($"features", $"weight")
- * val Row(Row(minVec, maxVec, count)) = multiStatsDF.first()
+ * val Row(minVec, maxVec, count) = multiStatsDF.first()
  * }}}
  *
  * If one wants to get a single metric, shortcuts are also available:

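Not part of the diff: a short sketch of the single-metric shortcuts referenced in the last context line of this scaladoc hunk, applied to the toy DataFrame from the first sketch. These helpers (`Summarizer.max`, `Summarizer.min`, `Summarizer.mean`) are existing methods on `org.apache.spark.ml.stat.Summarizer`.

```scala
// Assumes the DataFrame `df` with a "features" vector column from the first sketch.
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.sql.functions.col

// Single-metric shortcuts instead of Summarizer.metrics(...).summary(...).
df.select(Summarizer.max(col("features")), Summarizer.min(col("features"))).show(false)
df.select(Summarizer.mean(col("features"))).show(false)
```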