[SPARK-18518][ML] HasSolver supports override

## What changes were proposed in this pull request?
1, make param support non-final with `finalFields` option
2, generate `HasSolver` with `finalFields = false`
3, override `solver` in LiR, GLR, and make MLPC inherit `HasSolver`

## How was this patch tested?
existing tests

Author: Ruifeng Zheng <ruifengz@foxmail.com>
Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #16028 from zhengruifeng/param_non_final.
This commit is contained in:
Ruifeng Zheng 2017-07-01 15:37:41 +08:00 committed by Yanbo Liang
parent 37ef32e515
commit e0b047eafe
7 changed files with 82 additions and 46 deletions

View file

@ -27,13 +27,16 @@ import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param._ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTol} import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._ import org.apache.spark.ml.util._
import org.apache.spark.sql.Dataset import org.apache.spark.sql.Dataset
/** Params for Multilayer Perceptron. */ /** Params for Multilayer Perceptron. */
private[classification] trait MultilayerPerceptronParams extends PredictorParams private[classification] trait MultilayerPerceptronParams extends PredictorParams
with HasSeed with HasMaxIter with HasTol with HasStepSize { with HasSeed with HasMaxIter with HasTol with HasStepSize with HasSolver {
import MultilayerPerceptronClassifier._
/** /**
* Layer sizes including input size and output size. * Layer sizes including input size and output size.
* *
@ -78,14 +81,10 @@ private[classification] trait MultilayerPerceptronParams extends PredictorParams
* @group expertParam * @group expertParam
*/ */
@Since("2.0.0") @Since("2.0.0")
final val solver: Param[String] = new Param[String](this, "solver", final override val solver: Param[String] = new Param[String](this, "solver",
"The solver algorithm for optimization. Supported options: " + "The solver algorithm for optimization. Supported options: " +
s"${MultilayerPerceptronClassifier.supportedSolvers.mkString(", ")}. (Default l-bfgs)", s"${supportedSolvers.mkString(", ")}. (Default l-bfgs)",
ParamValidators.inArray[String](MultilayerPerceptronClassifier.supportedSolvers)) ParamValidators.inArray[String](supportedSolvers))
/** @group expertGetParam */
@Since("2.0.0")
final def getSolver: String = $(solver)
/** /**
* The initial weights of the model. * The initial weights of the model.
@ -101,7 +100,7 @@ private[classification] trait MultilayerPerceptronParams extends PredictorParams
final def getInitialWeights: Vector = $(initialWeights) final def getInitialWeights: Vector = $(initialWeights)
setDefault(maxIter -> 100, tol -> 1e-6, blockSize -> 128, setDefault(maxIter -> 100, tol -> 1e-6, blockSize -> 128,
solver -> MultilayerPerceptronClassifier.LBFGS, stepSize -> 0.03) solver -> LBFGS, stepSize -> 0.03)
} }
/** Label to vector converter. */ /** Label to vector converter. */

View file

@ -80,8 +80,7 @@ private[shared] object SharedParamsCodeGen {
" 0)", isValid = "ParamValidators.gt(0)"), " 0)", isValid = "ParamValidators.gt(0)"),
ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " + ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " +
"all instance weights as 1.0"), "all instance weights as 1.0"),
ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " + ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false),
"empty, default value is 'auto'", Some("\"auto\"")),
ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"), ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
isValid = "ParamValidators.gtEq(2)", isExpertParam = true)) isValid = "ParamValidators.gtEq(2)", isExpertParam = true))
@ -99,6 +98,7 @@ private[shared] object SharedParamsCodeGen {
defaultValueStr: Option[String] = None, defaultValueStr: Option[String] = None,
isValid: String = "", isValid: String = "",
finalMethods: Boolean = true, finalMethods: Boolean = true,
finalFields: Boolean = true,
isExpertParam: Boolean = false) { isExpertParam: Boolean = false) {
require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.") require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.")
@ -167,6 +167,11 @@ private[shared] object SharedParamsCodeGen {
} else { } else {
"def" "def"
} }
val fieldStr = if (param.finalFields) {
"final val"
} else {
"val"
}
val htmlCompliantDoc = Utility.escape(doc) val htmlCompliantDoc = Utility.escape(doc)
@ -180,7 +185,7 @@ private[shared] object SharedParamsCodeGen {
| * Param for $htmlCompliantDoc. | * Param for $htmlCompliantDoc.
| * @group ${groupStr(0)} | * @group ${groupStr(0)}
| */ | */
| final val $name: $Param = new $Param(this, "$name", "$doc"$isValid) | $fieldStr $name: $Param = new $Param(this, "$name", "$doc"$isValid)
|$setDefault |$setDefault
| /** @group ${groupStr(1)} */ | /** @group ${groupStr(1)} */
| $methodStr get$Name: $T = $$($name) | $methodStr get$Name: $T = $$($name)

View file

@ -374,17 +374,15 @@ private[ml] trait HasWeightCol extends Params {
} }
/** /**
* Trait for shared param solver (default: "auto"). * Trait for shared param solver.
*/ */
private[ml] trait HasSolver extends Params { private[ml] trait HasSolver extends Params {
/** /**
* Param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. * Param for the solver algorithm for optimization.
* @group param * @group param
*/ */
final val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'") val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization")
setDefault(solver, "auto")
/** @group getParam */ /** @group getParam */
final def getSolver: String = $(solver) final def getSolver: String = $(solver)

View file

@ -164,7 +164,18 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty
} }
import GeneralizedLinearRegression._ /**
* The solver algorithm for optimization.
* Supported options: "irls" (iteratively reweighted least squares).
* Default: "irls"
*
* @group param
*/
@Since("2.3.0")
final override val solver: Param[String] = new Param[String](this, "solver",
"The solver algorithm for optimization. Supported options: " +
s"${supportedSolvers.mkString(", ")}. (Default irls)",
ParamValidators.inArray[String](supportedSolvers))
@Since("2.0.0") @Since("2.0.0")
override def validateAndTransformSchema( override def validateAndTransformSchema(
@ -350,7 +361,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
*/ */
@Since("2.0.0") @Since("2.0.0")
def setSolver(value: String): this.type = set(solver, value) def setSolver(value: String): this.type = set(solver, value)
setDefault(solver -> "irls") setDefault(solver -> IRLS)
/** /**
* Sets the link prediction (linear predictor) column name. * Sets the link prediction (linear predictor) column name.
@ -442,6 +453,12 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
Gamma -> Inverse, Gamma -> Identity, Gamma -> Log Gamma -> Inverse, Gamma -> Identity, Gamma -> Log
) )
/** String name for "irls" (iteratively reweighted least squares) solver. */
private[regression] val IRLS = "irls"
/** Set of solvers that GeneralizedLinearRegression supports. */
private[regression] val supportedSolvers = Array(IRLS)
/** Set of family names that GeneralizedLinearRegression supports. */ /** Set of family names that GeneralizedLinearRegression supports. */
private[regression] lazy val supportedFamilyNames = private[regression] lazy val supportedFamilyNames =
supportedFamilyAndLinkPairs.map(_._1.name).toArray :+ "tweedie" supportedFamilyAndLinkPairs.map(_._1.name).toArray :+ "tweedie"

View file

@ -34,7 +34,7 @@ import org.apache.spark.ml.optim.WeightedLeastSquares
import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.optim.aggregator.LeastSquaresAggregator import org.apache.spark.ml.optim.aggregator.LeastSquaresAggregator
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._ import org.apache.spark.ml.util._
import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.evaluation.RegressionMetrics
@ -53,7 +53,23 @@ import org.apache.spark.storage.StorageLevel
private[regression] trait LinearRegressionParams extends PredictorParams private[regression] trait LinearRegressionParams extends PredictorParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
with HasAggregationDepth with HasAggregationDepth {
import LinearRegression._
/**
* The solver algorithm for optimization.
* Supported options: "l-bfgs", "normal" and "auto".
* Default: "auto"
*
* @group param
*/
@Since("2.3.0")
final override val solver: Param[String] = new Param[String](this, "solver",
"The solver algorithm for optimization. Supported options: " +
s"${supportedSolvers.mkString(", ")}. (Default auto)",
ParamValidators.inArray[String](supportedSolvers))
}
/** /**
* Linear regression. * Linear regression.
@ -78,6 +94,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
extends Regressor[Vector, LinearRegression, LinearRegressionModel] extends Regressor[Vector, LinearRegression, LinearRegressionModel]
with LinearRegressionParams with DefaultParamsWritable with Logging { with LinearRegressionParams with DefaultParamsWritable with Logging {
import LinearRegression._
@Since("1.4.0") @Since("1.4.0")
def this() = this(Identifiable.randomUID("linReg")) def this() = this(Identifiable.randomUID("linReg"))
@ -175,12 +193,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
* @group setParam * @group setParam
*/ */
@Since("1.6.0") @Since("1.6.0")
def setSolver(value: String): this.type = { def setSolver(value: String): this.type = set(solver, value)
require(Set("auto", "l-bfgs", "normal").contains(value), setDefault(solver -> AUTO)
s"Solver $value was not supported. Supported options: auto, l-bfgs, normal")
set(solver, value)
}
setDefault(solver -> "auto")
/** /**
* Suggested depth for treeAggregate (greater than or equal to 2). * Suggested depth for treeAggregate (greater than or equal to 2).
@ -210,8 +224,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth) elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth)
instr.logNumFeatures(numFeatures) instr.logNumFeatures(numFeatures)
if (($(solver) == "auto" && if (($(solver) == AUTO &&
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") { numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == NORMAL) {
// For low dimensional data, WeightedLeastSquares is more efficient since the // For low dimensional data, WeightedLeastSquares is more efficient since the
// training algorithm only requires one pass through the data. (SPARK-10668) // training algorithm only requires one pass through the data. (SPARK-10668)
@ -444,6 +458,18 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] {
*/ */
@Since("2.1.0") @Since("2.1.0")
val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES
/** String name for "auto". */
private[regression] val AUTO = "auto"
/** String name for "normal". */
private[regression] val NORMAL = "normal"
/** String name for "l-bfgs". */
private[regression] val LBFGS = "l-bfgs"
/** Set of solvers that LinearRegression supports. */
private[regression] val supportedSolvers = Array(AUTO, NORMAL, LBFGS)
} }
/** /**

View file

@ -1265,8 +1265,8 @@ class NaiveBayesModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaML
@inherit_doc @inherit_doc
class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
HasMaxIter, HasTol, HasSeed, HasStepSize, JavaMLWritable, HasMaxIter, HasTol, HasSeed, HasStepSize, HasSolver,
JavaMLReadable): JavaMLWritable, JavaMLReadable):
""" """
Classifier trainer based on the Multilayer Perceptron. Classifier trainer based on the Multilayer Perceptron.
Each layer has sigmoid activation function, output layer has softmax. Each layer has sigmoid activation function, output layer has softmax.
@ -1407,20 +1407,6 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
""" """
return self.getOrDefault(self.stepSize) return self.getOrDefault(self.stepSize)
@since("2.0.0")
def setSolver(self, value):
"""
Sets the value of :py:attr:`solver`.
"""
return self._set(solver=value)
@since("2.0.0")
def getSolver(self):
"""
Gets the value of solver or its default value.
"""
return self.getOrDefault(self.solver)
@since("2.0.0") @since("2.0.0")
def setInitialWeights(self, value): def setInitialWeights(self, value):
""" """

View file

@ -95,6 +95,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
.. versionadded:: 1.4.0 .. versionadded:: 1.4.0
""" """
solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
"options: auto, normal, l-bfgs.", typeConverter=TypeConverters.toString)
@keyword_only @keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
@ -1371,6 +1374,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " + linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " +
"Only applicable to the Tweedie family.", "Only applicable to the Tweedie family.",
typeConverter=TypeConverters.toFloat) typeConverter=TypeConverters.toFloat)
solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
"options: irls.", typeConverter=TypeConverters.toString)
@keyword_only @keyword_only
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",