[SPARK-18518][ML] HasSolver supports override
## What changes were proposed in this pull request? 1, make param support non-final with `finalFields` option 2, generate `HasSolver` with `finalFields = false` 3, override `solver` in LiR, GLR, and make MLPC inherit `HasSolver` ## How was this patch tested? existing tests Author: Ruifeng Zheng <ruifengz@foxmail.com> Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #16028 from zhengruifeng/param_non_final.
This commit is contained in:
parent
37ef32e515
commit
e0b047eafe
|
@ -27,13 +27,16 @@ import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
|
||||||
import org.apache.spark.ml.feature.LabeledPoint
|
import org.apache.spark.ml.feature.LabeledPoint
|
||||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||||
import org.apache.spark.ml.param._
|
import org.apache.spark.ml.param._
|
||||||
import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTol}
|
import org.apache.spark.ml.param.shared._
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.sql.Dataset
|
import org.apache.spark.sql.Dataset
|
||||||
|
|
||||||
/** Params for Multilayer Perceptron. */
|
/** Params for Multilayer Perceptron. */
|
||||||
private[classification] trait MultilayerPerceptronParams extends PredictorParams
|
private[classification] trait MultilayerPerceptronParams extends PredictorParams
|
||||||
with HasSeed with HasMaxIter with HasTol with HasStepSize {
|
with HasSeed with HasMaxIter with HasTol with HasStepSize with HasSolver {
|
||||||
|
|
||||||
|
import MultilayerPerceptronClassifier._
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Layer sizes including input size and output size.
|
* Layer sizes including input size and output size.
|
||||||
*
|
*
|
||||||
|
@ -78,14 +81,10 @@ private[classification] trait MultilayerPerceptronParams extends PredictorParams
|
||||||
* @group expertParam
|
* @group expertParam
|
||||||
*/
|
*/
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
final val solver: Param[String] = new Param[String](this, "solver",
|
final override val solver: Param[String] = new Param[String](this, "solver",
|
||||||
"The solver algorithm for optimization. Supported options: " +
|
"The solver algorithm for optimization. Supported options: " +
|
||||||
s"${MultilayerPerceptronClassifier.supportedSolvers.mkString(", ")}. (Default l-bfgs)",
|
s"${supportedSolvers.mkString(", ")}. (Default l-bfgs)",
|
||||||
ParamValidators.inArray[String](MultilayerPerceptronClassifier.supportedSolvers))
|
ParamValidators.inArray[String](supportedSolvers))
|
||||||
|
|
||||||
/** @group expertGetParam */
|
|
||||||
@Since("2.0.0")
|
|
||||||
final def getSolver: String = $(solver)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The initial weights of the model.
|
* The initial weights of the model.
|
||||||
|
@ -101,7 +100,7 @@ private[classification] trait MultilayerPerceptronParams extends PredictorParams
|
||||||
final def getInitialWeights: Vector = $(initialWeights)
|
final def getInitialWeights: Vector = $(initialWeights)
|
||||||
|
|
||||||
setDefault(maxIter -> 100, tol -> 1e-6, blockSize -> 128,
|
setDefault(maxIter -> 100, tol -> 1e-6, blockSize -> 128,
|
||||||
solver -> MultilayerPerceptronClassifier.LBFGS, stepSize -> 0.03)
|
solver -> LBFGS, stepSize -> 0.03)
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Label to vector converter. */
|
/** Label to vector converter. */
|
||||||
|
|
|
@ -80,8 +80,7 @@ private[shared] object SharedParamsCodeGen {
|
||||||
" 0)", isValid = "ParamValidators.gt(0)"),
|
" 0)", isValid = "ParamValidators.gt(0)"),
|
||||||
ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " +
|
ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " +
|
||||||
"all instance weights as 1.0"),
|
"all instance weights as 1.0"),
|
||||||
ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " +
|
ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false),
|
||||||
"empty, default value is 'auto'", Some("\"auto\"")),
|
|
||||||
ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
|
ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
|
||||||
isValid = "ParamValidators.gtEq(2)", isExpertParam = true))
|
isValid = "ParamValidators.gtEq(2)", isExpertParam = true))
|
||||||
|
|
||||||
|
@ -99,6 +98,7 @@ private[shared] object SharedParamsCodeGen {
|
||||||
defaultValueStr: Option[String] = None,
|
defaultValueStr: Option[String] = None,
|
||||||
isValid: String = "",
|
isValid: String = "",
|
||||||
finalMethods: Boolean = true,
|
finalMethods: Boolean = true,
|
||||||
|
finalFields: Boolean = true,
|
||||||
isExpertParam: Boolean = false) {
|
isExpertParam: Boolean = false) {
|
||||||
|
|
||||||
require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.")
|
require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.")
|
||||||
|
@ -167,6 +167,11 @@ private[shared] object SharedParamsCodeGen {
|
||||||
} else {
|
} else {
|
||||||
"def"
|
"def"
|
||||||
}
|
}
|
||||||
|
val fieldStr = if (param.finalFields) {
|
||||||
|
"final val"
|
||||||
|
} else {
|
||||||
|
"val"
|
||||||
|
}
|
||||||
|
|
||||||
val htmlCompliantDoc = Utility.escape(doc)
|
val htmlCompliantDoc = Utility.escape(doc)
|
||||||
|
|
||||||
|
@ -180,7 +185,7 @@ private[shared] object SharedParamsCodeGen {
|
||||||
| * Param for $htmlCompliantDoc.
|
| * Param for $htmlCompliantDoc.
|
||||||
| * @group ${groupStr(0)}
|
| * @group ${groupStr(0)}
|
||||||
| */
|
| */
|
||||||
| final val $name: $Param = new $Param(this, "$name", "$doc"$isValid)
|
| $fieldStr $name: $Param = new $Param(this, "$name", "$doc"$isValid)
|
||||||
|$setDefault
|
|$setDefault
|
||||||
| /** @group ${groupStr(1)} */
|
| /** @group ${groupStr(1)} */
|
||||||
| $methodStr get$Name: $T = $$($name)
|
| $methodStr get$Name: $T = $$($name)
|
||||||
|
|
|
@ -374,17 +374,15 @@ private[ml] trait HasWeightCol extends Params {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Trait for shared param solver (default: "auto").
|
* Trait for shared param solver.
|
||||||
*/
|
*/
|
||||||
private[ml] trait HasSolver extends Params {
|
private[ml] trait HasSolver extends Params {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.
|
* Param for the solver algorithm for optimization.
|
||||||
* @group param
|
* @group param
|
||||||
*/
|
*/
|
||||||
final val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'")
|
val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization")
|
||||||
|
|
||||||
setDefault(solver, "auto")
|
|
||||||
|
|
||||||
/** @group getParam */
|
/** @group getParam */
|
||||||
final def getSolver: String = $(solver)
|
final def getSolver: String = $(solver)
|
||||||
|
|
|
@ -164,7 +164,18 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
|
||||||
isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty
|
isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty
|
||||||
}
|
}
|
||||||
|
|
||||||
import GeneralizedLinearRegression._
|
/**
|
||||||
|
* The solver algorithm for optimization.
|
||||||
|
* Supported options: "irls" (iteratively reweighted least squares).
|
||||||
|
* Default: "irls"
|
||||||
|
*
|
||||||
|
* @group param
|
||||||
|
*/
|
||||||
|
@Since("2.3.0")
|
||||||
|
final override val solver: Param[String] = new Param[String](this, "solver",
|
||||||
|
"The solver algorithm for optimization. Supported options: " +
|
||||||
|
s"${supportedSolvers.mkString(", ")}. (Default irls)",
|
||||||
|
ParamValidators.inArray[String](supportedSolvers))
|
||||||
|
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
override def validateAndTransformSchema(
|
override def validateAndTransformSchema(
|
||||||
|
@ -350,7 +361,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
|
||||||
*/
|
*/
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
def setSolver(value: String): this.type = set(solver, value)
|
def setSolver(value: String): this.type = set(solver, value)
|
||||||
setDefault(solver -> "irls")
|
setDefault(solver -> IRLS)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets the link prediction (linear predictor) column name.
|
* Sets the link prediction (linear predictor) column name.
|
||||||
|
@ -442,6 +453,12 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
|
||||||
Gamma -> Inverse, Gamma -> Identity, Gamma -> Log
|
Gamma -> Inverse, Gamma -> Identity, Gamma -> Log
|
||||||
)
|
)
|
||||||
|
|
||||||
|
/** String name for "irls" (iteratively reweighted least squares) solver. */
|
||||||
|
private[regression] val IRLS = "irls"
|
||||||
|
|
||||||
|
/** Set of solvers that GeneralizedLinearRegression supports. */
|
||||||
|
private[regression] val supportedSolvers = Array(IRLS)
|
||||||
|
|
||||||
/** Set of family names that GeneralizedLinearRegression supports. */
|
/** Set of family names that GeneralizedLinearRegression supports. */
|
||||||
private[regression] lazy val supportedFamilyNames =
|
private[regression] lazy val supportedFamilyNames =
|
||||||
supportedFamilyAndLinkPairs.map(_._1.name).toArray :+ "tweedie"
|
supportedFamilyAndLinkPairs.map(_._1.name).toArray :+ "tweedie"
|
||||||
|
|
|
@ -34,7 +34,7 @@ import org.apache.spark.ml.optim.WeightedLeastSquares
|
||||||
import org.apache.spark.ml.PredictorParams
|
import org.apache.spark.ml.PredictorParams
|
||||||
import org.apache.spark.ml.optim.aggregator.LeastSquaresAggregator
|
import org.apache.spark.ml.optim.aggregator.LeastSquaresAggregator
|
||||||
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
|
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
|
||||||
import org.apache.spark.ml.param.ParamMap
|
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
|
||||||
import org.apache.spark.ml.param.shared._
|
import org.apache.spark.ml.param.shared._
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
import org.apache.spark.mllib.evaluation.RegressionMetrics
|
import org.apache.spark.mllib.evaluation.RegressionMetrics
|
||||||
|
@ -53,7 +53,23 @@ import org.apache.spark.storage.StorageLevel
|
||||||
private[regression] trait LinearRegressionParams extends PredictorParams
|
private[regression] trait LinearRegressionParams extends PredictorParams
|
||||||
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
|
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
|
||||||
with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
|
with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
|
||||||
with HasAggregationDepth
|
with HasAggregationDepth {
|
||||||
|
|
||||||
|
import LinearRegression._
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The solver algorithm for optimization.
|
||||||
|
* Supported options: "l-bfgs", "normal" and "auto".
|
||||||
|
* Default: "auto"
|
||||||
|
*
|
||||||
|
* @group param
|
||||||
|
*/
|
||||||
|
@Since("2.3.0")
|
||||||
|
final override val solver: Param[String] = new Param[String](this, "solver",
|
||||||
|
"The solver algorithm for optimization. Supported options: " +
|
||||||
|
s"${supportedSolvers.mkString(", ")}. (Default auto)",
|
||||||
|
ParamValidators.inArray[String](supportedSolvers))
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Linear regression.
|
* Linear regression.
|
||||||
|
@ -78,6 +94,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
|
||||||
extends Regressor[Vector, LinearRegression, LinearRegressionModel]
|
extends Regressor[Vector, LinearRegression, LinearRegressionModel]
|
||||||
with LinearRegressionParams with DefaultParamsWritable with Logging {
|
with LinearRegressionParams with DefaultParamsWritable with Logging {
|
||||||
|
|
||||||
|
import LinearRegression._
|
||||||
|
|
||||||
@Since("1.4.0")
|
@Since("1.4.0")
|
||||||
def this() = this(Identifiable.randomUID("linReg"))
|
def this() = this(Identifiable.randomUID("linReg"))
|
||||||
|
|
||||||
|
@ -175,12 +193,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
|
||||||
* @group setParam
|
* @group setParam
|
||||||
*/
|
*/
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
def setSolver(value: String): this.type = {
|
def setSolver(value: String): this.type = set(solver, value)
|
||||||
require(Set("auto", "l-bfgs", "normal").contains(value),
|
setDefault(solver -> AUTO)
|
||||||
s"Solver $value was not supported. Supported options: auto, l-bfgs, normal")
|
|
||||||
set(solver, value)
|
|
||||||
}
|
|
||||||
setDefault(solver -> "auto")
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Suggested depth for treeAggregate (greater than or equal to 2).
|
* Suggested depth for treeAggregate (greater than or equal to 2).
|
||||||
|
@ -210,8 +224,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
|
||||||
elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth)
|
elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth)
|
||||||
instr.logNumFeatures(numFeatures)
|
instr.logNumFeatures(numFeatures)
|
||||||
|
|
||||||
if (($(solver) == "auto" &&
|
if (($(solver) == AUTO &&
|
||||||
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") {
|
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == NORMAL) {
|
||||||
// For low dimensional data, WeightedLeastSquares is more efficient since the
|
// For low dimensional data, WeightedLeastSquares is more efficient since the
|
||||||
// training algorithm only requires one pass through the data. (SPARK-10668)
|
// training algorithm only requires one pass through the data. (SPARK-10668)
|
||||||
|
|
||||||
|
@ -444,6 +458,18 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] {
|
||||||
*/
|
*/
|
||||||
@Since("2.1.0")
|
@Since("2.1.0")
|
||||||
val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES
|
val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES
|
||||||
|
|
||||||
|
/** String name for "auto". */
|
||||||
|
private[regression] val AUTO = "auto"
|
||||||
|
|
||||||
|
/** String name for "normal". */
|
||||||
|
private[regression] val NORMAL = "normal"
|
||||||
|
|
||||||
|
/** String name for "l-bfgs". */
|
||||||
|
private[regression] val LBFGS = "l-bfgs"
|
||||||
|
|
||||||
|
/** Set of solvers that LinearRegression supports. */
|
||||||
|
private[regression] val supportedSolvers = Array(AUTO, NORMAL, LBFGS)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -1265,8 +1265,8 @@ class NaiveBayesModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaML
|
||||||
|
|
||||||
@inherit_doc
|
@inherit_doc
|
||||||
class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
|
class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
|
||||||
HasMaxIter, HasTol, HasSeed, HasStepSize, JavaMLWritable,
|
HasMaxIter, HasTol, HasSeed, HasStepSize, HasSolver,
|
||||||
JavaMLReadable):
|
JavaMLWritable, JavaMLReadable):
|
||||||
"""
|
"""
|
||||||
Classifier trainer based on the Multilayer Perceptron.
|
Classifier trainer based on the Multilayer Perceptron.
|
||||||
Each layer has sigmoid activation function, output layer has softmax.
|
Each layer has sigmoid activation function, output layer has softmax.
|
||||||
|
@ -1407,20 +1407,6 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
|
||||||
"""
|
"""
|
||||||
return self.getOrDefault(self.stepSize)
|
return self.getOrDefault(self.stepSize)
|
||||||
|
|
||||||
@since("2.0.0")
|
|
||||||
def setSolver(self, value):
|
|
||||||
"""
|
|
||||||
Sets the value of :py:attr:`solver`.
|
|
||||||
"""
|
|
||||||
return self._set(solver=value)
|
|
||||||
|
|
||||||
@since("2.0.0")
|
|
||||||
def getSolver(self):
|
|
||||||
"""
|
|
||||||
Gets the value of solver or its default value.
|
|
||||||
"""
|
|
||||||
return self.getOrDefault(self.solver)
|
|
||||||
|
|
||||||
@since("2.0.0")
|
@since("2.0.0")
|
||||||
def setInitialWeights(self, value):
|
def setInitialWeights(self, value):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -95,6 +95,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
|
||||||
.. versionadded:: 1.4.0
|
.. versionadded:: 1.4.0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
|
||||||
|
"options: auto, normal, l-bfgs.", typeConverter=TypeConverters.toString)
|
||||||
|
|
||||||
@keyword_only
|
@keyword_only
|
||||||
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
|
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
|
||||||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
|
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
|
||||||
|
@ -1371,6 +1374,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
|
||||||
linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " +
|
linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " +
|
||||||
"Only applicable to the Tweedie family.",
|
"Only applicable to the Tweedie family.",
|
||||||
typeConverter=TypeConverters.toFloat)
|
typeConverter=TypeConverters.toFloat)
|
||||||
|
solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
|
||||||
|
"options: irls.", typeConverter=TypeConverters.toString)
|
||||||
|
|
||||||
@keyword_only
|
@keyword_only
|
||||||
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
|
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
|
||||||
|
|
Loading…
Reference in a new issue