[SPARK-5924] Add the ability to specify withMean or withStd parameters with StandarScaler
The current implementation call the default constructor of mllib.feature.StandarScaler without the possibility to specify withMean or withStd options. Author: jrabary <Jaonary@gmail.com> Closes #4704 from jrabary/master and squashes the following commits: fae8568 [jrabary] style fix 8896b0e [jrabary] Comments fix ef96d73 [jrabary] style fix 8e52607 [jrabary] style fix edd9d48 [jrabary] Fix default param initialization 17e1a76 [jrabary] Fix default param initialization 298f405 [jrabary] Typo fix 45ed914 [jrabary] Add withMean and withStd params to StandarScaler
This commit is contained in:
parent
6fe690d5a8
commit
1be207078c
|
@ -30,7 +30,22 @@ import org.apache.spark.sql.types.{StructField, StructType}
|
|||
/**
|
||||
* Params for [[StandardScaler]] and [[StandardScalerModel]].
|
||||
*/
|
||||
private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol
|
||||
private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol {
|
||||
|
||||
/**
|
||||
* False by default. Centers the data with mean before scaling.
|
||||
* It will build a dense output, so this does not work on sparse input
|
||||
* and will raise an exception.
|
||||
* @group param
|
||||
*/
|
||||
val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean")
|
||||
|
||||
/**
|
||||
* True by default. Scales the data to unit standard deviation.
|
||||
* @group param
|
||||
*/
|
||||
val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation")
|
||||
}
|
||||
|
||||
/**
|
||||
* :: AlphaComponent ::
|
||||
|
@ -40,18 +55,27 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
|
|||
@AlphaComponent
|
||||
class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams {
|
||||
|
||||
setDefault(withMean -> false, withStd -> true)
|
||||
|
||||
/** @group setParam */
|
||||
def setInputCol(value: String): this.type = set(inputCol, value)
|
||||
|
||||
/** @group setParam */
|
||||
def setOutputCol(value: String): this.type = set(outputCol, value)
|
||||
|
||||
|
||||
/** @group setParam */
|
||||
def setWithMean(value: Boolean): this.type = set(withMean, value)
|
||||
|
||||
/** @group setParam */
|
||||
def setWithStd(value: Boolean): this.type = set(withStd, value)
|
||||
|
||||
override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = {
|
||||
transformSchema(dataset.schema, paramMap, logging = true)
|
||||
val map = extractParamMap(paramMap)
|
||||
val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
|
||||
val scaler = new feature.StandardScaler().fit(input)
|
||||
val model = new StandardScalerModel(this, map, scaler)
|
||||
val scaler = new feature.StandardScaler(withMean = map(withMean), withStd = map(withStd))
|
||||
val scalerModel = scaler.fit(input)
|
||||
val model = new StandardScalerModel(this, map, scalerModel)
|
||||
Params.inheritValues(map, this, model)
|
||||
model
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue