[SPARK-18518][ML] HasSolver supports override
## What changes were proposed in this pull request? 1, make param support non-final with `finalFields` option 2, generate `HasSolver` with `finalFields = false` 3, override `solver` in LiR, GLR, and make MLPC inherit `HasSolver` ## How was this patch tested? existing tests Author: Ruifeng Zheng <ruifengz@foxmail.com> Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #16028 from zhengruifeng/param_non_final.
This commit is contained in:
parent
37ef32e515
commit
e0b047eafe
|
@ -27,13 +27,16 @@ import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
|
|||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTol}
|
||||
import org.apache.spark.ml.param.shared._
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.sql.Dataset
|
||||
|
||||
/** Params for Multilayer Perceptron. */
|
||||
private[classification] trait MultilayerPerceptronParams extends PredictorParams
|
||||
with HasSeed with HasMaxIter with HasTol with HasStepSize {
|
||||
with HasSeed with HasMaxIter with HasTol with HasStepSize with HasSolver {
|
||||
|
||||
import MultilayerPerceptronClassifier._
|
||||
|
||||
/**
|
||||
* Layer sizes including input size and output size.
|
||||
*
|
||||
|
@ -78,14 +81,10 @@ private[classification] trait MultilayerPerceptronParams extends PredictorParams
|
|||
* @group expertParam
|
||||
*/
|
||||
@Since("2.0.0")
|
||||
final val solver: Param[String] = new Param[String](this, "solver",
|
||||
final override val solver: Param[String] = new Param[String](this, "solver",
|
||||
"The solver algorithm for optimization. Supported options: " +
|
||||
s"${MultilayerPerceptronClassifier.supportedSolvers.mkString(", ")}. (Default l-bfgs)",
|
||||
ParamValidators.inArray[String](MultilayerPerceptronClassifier.supportedSolvers))
|
||||
|
||||
/** @group expertGetParam */
|
||||
@Since("2.0.0")
|
||||
final def getSolver: String = $(solver)
|
||||
s"${supportedSolvers.mkString(", ")}. (Default l-bfgs)",
|
||||
ParamValidators.inArray[String](supportedSolvers))
|
||||
|
||||
/**
|
||||
* The initial weights of the model.
|
||||
|
@ -101,7 +100,7 @@ private[classification] trait MultilayerPerceptronParams extends PredictorParams
|
|||
final def getInitialWeights: Vector = $(initialWeights)
|
||||
|
||||
setDefault(maxIter -> 100, tol -> 1e-6, blockSize -> 128,
|
||||
solver -> MultilayerPerceptronClassifier.LBFGS, stepSize -> 0.03)
|
||||
solver -> LBFGS, stepSize -> 0.03)
|
||||
}
|
||||
|
||||
/** Label to vector converter. */
|
||||
|
|
|
@ -80,8 +80,7 @@ private[shared] object SharedParamsCodeGen {
|
|||
" 0)", isValid = "ParamValidators.gt(0)"),
|
||||
ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " +
|
||||
"all instance weights as 1.0"),
|
||||
ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " +
|
||||
"empty, default value is 'auto'", Some("\"auto\"")),
|
||||
ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false),
|
||||
ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
|
||||
isValid = "ParamValidators.gtEq(2)", isExpertParam = true))
|
||||
|
||||
|
@ -99,6 +98,7 @@ private[shared] object SharedParamsCodeGen {
|
|||
defaultValueStr: Option[String] = None,
|
||||
isValid: String = "",
|
||||
finalMethods: Boolean = true,
|
||||
finalFields: Boolean = true,
|
||||
isExpertParam: Boolean = false) {
|
||||
|
||||
require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.")
|
||||
|
@ -167,6 +167,11 @@ private[shared] object SharedParamsCodeGen {
|
|||
} else {
|
||||
"def"
|
||||
}
|
||||
val fieldStr = if (param.finalFields) {
|
||||
"final val"
|
||||
} else {
|
||||
"val"
|
||||
}
|
||||
|
||||
val htmlCompliantDoc = Utility.escape(doc)
|
||||
|
||||
|
@ -180,7 +185,7 @@ private[shared] object SharedParamsCodeGen {
|
|||
| * Param for $htmlCompliantDoc.
|
||||
| * @group ${groupStr(0)}
|
||||
| */
|
||||
| final val $name: $Param = new $Param(this, "$name", "$doc"$isValid)
|
||||
| $fieldStr $name: $Param = new $Param(this, "$name", "$doc"$isValid)
|
||||
|$setDefault
|
||||
| /** @group ${groupStr(1)} */
|
||||
| $methodStr get$Name: $T = $$($name)
|
||||
|
|
|
@ -374,17 +374,15 @@ private[ml] trait HasWeightCol extends Params {
|
|||
}
|
||||
|
||||
/**
|
||||
* Trait for shared param solver (default: "auto").
|
||||
* Trait for shared param solver.
|
||||
*/
|
||||
private[ml] trait HasSolver extends Params {
|
||||
|
||||
/**
|
||||
* Param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.
|
||||
* Param for the solver algorithm for optimization.
|
||||
* @group param
|
||||
*/
|
||||
final val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'")
|
||||
|
||||
setDefault(solver, "auto")
|
||||
val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization")
|
||||
|
||||
/** @group getParam */
|
||||
final def getSolver: String = $(solver)
|
||||
|
|
|
@ -164,7 +164,18 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
|
|||
isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty
|
||||
}
|
||||
|
||||
import GeneralizedLinearRegression._
|
||||
/**
|
||||
* The solver algorithm for optimization.
|
||||
* Supported options: "irls" (iteratively reweighted least squares).
|
||||
* Default: "irls"
|
||||
*
|
||||
* @group param
|
||||
*/
|
||||
@Since("2.3.0")
|
||||
final override val solver: Param[String] = new Param[String](this, "solver",
|
||||
"The solver algorithm for optimization. Supported options: " +
|
||||
s"${supportedSolvers.mkString(", ")}. (Default irls)",
|
||||
ParamValidators.inArray[String](supportedSolvers))
|
||||
|
||||
@Since("2.0.0")
|
||||
override def validateAndTransformSchema(
|
||||
|
@ -350,7 +361,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
|
|||
*/
|
||||
@Since("2.0.0")
|
||||
def setSolver(value: String): this.type = set(solver, value)
|
||||
setDefault(solver -> "irls")
|
||||
setDefault(solver -> IRLS)
|
||||
|
||||
/**
|
||||
* Sets the link prediction (linear predictor) column name.
|
||||
|
@ -442,6 +453,12 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
|||
Gamma -> Inverse, Gamma -> Identity, Gamma -> Log
|
||||
)
|
||||
|
||||
/** String name for "irls" (iteratively reweighted least squares) solver. */
|
||||
private[regression] val IRLS = "irls"
|
||||
|
||||
/** Set of solvers that GeneralizedLinearRegression supports. */
|
||||
private[regression] val supportedSolvers = Array(IRLS)
|
||||
|
||||
/** Set of family names that GeneralizedLinearRegression supports. */
|
||||
private[regression] lazy val supportedFamilyNames =
|
||||
supportedFamilyAndLinkPairs.map(_._1.name).toArray :+ "tweedie"
|
||||
|
|
|
@ -34,7 +34,7 @@ import org.apache.spark.ml.optim.WeightedLeastSquares
|
|||
import org.apache.spark.ml.PredictorParams
|
||||
import org.apache.spark.ml.optim.aggregator.LeastSquaresAggregator
|
||||
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
|
||||
import org.apache.spark.ml.param.shared._
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.mllib.evaluation.RegressionMetrics
|
||||
|
@ -53,7 +53,23 @@ import org.apache.spark.storage.StorageLevel
|
|||
private[regression] trait LinearRegressionParams extends PredictorParams
|
||||
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
|
||||
with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
|
||||
with HasAggregationDepth
|
||||
with HasAggregationDepth {
|
||||
|
||||
import LinearRegression._
|
||||
|
||||
/**
|
||||
* The solver algorithm for optimization.
|
||||
* Supported options: "l-bfgs", "normal" and "auto".
|
||||
* Default: "auto"
|
||||
*
|
||||
* @group param
|
||||
*/
|
||||
@Since("2.3.0")
|
||||
final override val solver: Param[String] = new Param[String](this, "solver",
|
||||
"The solver algorithm for optimization. Supported options: " +
|
||||
s"${supportedSolvers.mkString(", ")}. (Default auto)",
|
||||
ParamValidators.inArray[String](supportedSolvers))
|
||||
}
|
||||
|
||||
/**
|
||||
* Linear regression.
|
||||
|
@ -78,6 +94,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
|
|||
extends Regressor[Vector, LinearRegression, LinearRegressionModel]
|
||||
with LinearRegressionParams with DefaultParamsWritable with Logging {
|
||||
|
||||
import LinearRegression._
|
||||
|
||||
@Since("1.4.0")
|
||||
def this() = this(Identifiable.randomUID("linReg"))
|
||||
|
||||
|
@ -175,12 +193,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
|
|||
* @group setParam
|
||||
*/
|
||||
@Since("1.6.0")
|
||||
def setSolver(value: String): this.type = {
|
||||
require(Set("auto", "l-bfgs", "normal").contains(value),
|
||||
s"Solver $value was not supported. Supported options: auto, l-bfgs, normal")
|
||||
set(solver, value)
|
||||
}
|
||||
setDefault(solver -> "auto")
|
||||
def setSolver(value: String): this.type = set(solver, value)
|
||||
setDefault(solver -> AUTO)
|
||||
|
||||
/**
|
||||
* Suggested depth for treeAggregate (greater than or equal to 2).
|
||||
|
@ -210,8 +224,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
|
|||
elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth)
|
||||
instr.logNumFeatures(numFeatures)
|
||||
|
||||
if (($(solver) == "auto" &&
|
||||
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") {
|
||||
if (($(solver) == AUTO &&
|
||||
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == NORMAL) {
|
||||
// For low dimensional data, WeightedLeastSquares is more efficient since the
|
||||
// training algorithm only requires one pass through the data. (SPARK-10668)
|
||||
|
||||
|
@ -444,6 +458,18 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] {
|
|||
*/
|
||||
@Since("2.1.0")
|
||||
val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES
|
||||
|
||||
/** String name for "auto". */
|
||||
private[regression] val AUTO = "auto"
|
||||
|
||||
/** String name for "normal". */
|
||||
private[regression] val NORMAL = "normal"
|
||||
|
||||
/** String name for "l-bfgs". */
|
||||
private[regression] val LBFGS = "l-bfgs"
|
||||
|
||||
/** Set of solvers that LinearRegression supports. */
|
||||
private[regression] val supportedSolvers = Array(AUTO, NORMAL, LBFGS)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -1265,8 +1265,8 @@ class NaiveBayesModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaML
|
|||
|
||||
@inherit_doc
|
||||
class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
|
||||
HasMaxIter, HasTol, HasSeed, HasStepSize, JavaMLWritable,
|
||||
JavaMLReadable):
|
||||
HasMaxIter, HasTol, HasSeed, HasStepSize, HasSolver,
|
||||
JavaMLWritable, JavaMLReadable):
|
||||
"""
|
||||
Classifier trainer based on the Multilayer Perceptron.
|
||||
Each layer has sigmoid activation function, output layer has softmax.
|
||||
|
@ -1407,20 +1407,6 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
|
|||
"""
|
||||
return self.getOrDefault(self.stepSize)
|
||||
|
||||
@since("2.0.0")
|
||||
def setSolver(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`solver`.
|
||||
"""
|
||||
return self._set(solver=value)
|
||||
|
||||
@since("2.0.0")
|
||||
def getSolver(self):
|
||||
"""
|
||||
Gets the value of solver or its default value.
|
||||
"""
|
||||
return self.getOrDefault(self.solver)
|
||||
|
||||
@since("2.0.0")
|
||||
def setInitialWeights(self, value):
|
||||
"""
|
||||
|
|
|
@ -95,6 +95,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
|
|||
.. versionadded:: 1.4.0
|
||||
"""
|
||||
|
||||
solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
|
||||
"options: auto, normal, l-bfgs.", typeConverter=TypeConverters.toString)
|
||||
|
||||
@keyword_only
|
||||
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
|
||||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
|
||||
|
@ -1371,6 +1374,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
|
|||
linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " +
|
||||
"Only applicable to the Tweedie family.",
|
||||
typeConverter=TypeConverters.toFloat)
|
||||
solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
|
||||
"options: irls.", typeConverter=TypeConverters.toString)
|
||||
|
||||
@keyword_only
|
||||
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
|
||||
|
|
Loading…
Reference in a new issue