[SPARK-32232][ML][PYSPARK] Make sure ML has the same default solver values between Scala and Python

### What changes were proposed in this pull request?
Current problems:
```
mlp = MultilayerPerceptronClassifier(layers=[2, 2, 2], seed=123)
model = mlp.fit(df)
path = tempfile.mkdtemp()
model_path = path + "/mlp"
model.save(model_path)
model2 = MultilayerPerceptronClassificationModel.load(model_path)
self.assertEqual(model2.getSolver(), "l-bfgs")
# fails: model2.getSolver() returns 'auto'
model2.transform(df)
# fails with pyspark.sql.utils.IllegalArgumentException:
# MultilayerPerceptronClassifier_dec859ed24ec parameter solver given invalid value auto.
```
FMClassifier/FMRegressor and GeneralizedLinearRegression have the same problems.
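For concreteness, here is a minimal sketch of the analogous failure for GeneralizedLinearRegression (same pattern as the MLP repro above; the toy DataFrame `df` and active SparkSession are assumed):
```
# Sketch of the analogous GLR failure; assumes an active SparkSession and
# a toy DataFrame `df` with "label" and "features" columns.
import tempfile
from pyspark.ml.regression import GeneralizedLinearRegression, \
    GeneralizedLinearRegressionModel

glr = GeneralizedLinearRegression(family="gaussian", link="identity")
model = glr.fit(df)
model_path = tempfile.mkdtemp() + "/glr"
model.save(model_path)
model2 = GeneralizedLinearRegressionModel.load(model_path)
model2.getSolver()    # returns 'auto' instead of the Scala default 'irls'
model2.transform(df)  # raises IllegalArgumentException: ... given invalid value auto.
```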

Here are the root causes of the problems:
1. In HasSolver, both Scala and Python default solver to 'auto'.

2. On the Scala side, mlp overrides the default solver to 'l-bfgs', FMClassifier/FMRegressor override it to 'adamW', and glr overrides it to 'irls'.

3. On the Scala side, mlp overrides the default solver in MultilayerPerceptronClassificationParams, so both MultilayerPerceptronClassifier and MultilayerPerceptronClassificationModel have 'l-bfgs' as the default.

4. On the Python side, mlp overrides the default solver in MultilayerPerceptronClassifier, so the estimator's default is 'l-bfgs', but MultilayerPerceptronClassificationModel doesn't override it, so it inherits the 'auto' default from HasSolver. In theory, we don't care about the solver value, or any other param values, for MultilayerPerceptronClassificationModel, because we already have the fitted model. That's why, on the Python side, we never set default values for any of the XXXModel classes.

5. When calling getSolver on the loaded mlp model, this code runs underneath:
```
    def _transfer_params_from_java(self):
        """
        Transforms the embedded params from the companion Java object.
        """
        ......
                # SPARK-14931: Only check set params back to avoid default params mismatch.
                if self._java_obj.isSet(java_param):
                    value = _java2py(sc, self._java_obj.getOrDefault(java_param))
                    self._set(**{param.name: value})
        ......
```
That's why model2.getSolver() returns 'auto': the code doesn't transfer the Scala default value (in this case 'l-bfgs') to the Python param, so the Python-side default (in this case 'auto') wins.
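The set-vs-default distinction can be seen directly on the estimator (a small sketch, assuming an active SparkSession):
```
# A default value is not "set": isSet() returns False for it, so the
# SPARK-14931 check above skips defaults entirely.
from pyspark.ml.classification import MultilayerPerceptronClassifier

mlp = MultilayerPerceptronClassifier(layers=[2, 2, 2])
print(mlp.isSet(mlp.solver))       # False: 'l-bfgs' is only a default
print(mlp.hasDefault(mlp.solver))  # True
print(mlp.getSolver())             # 'l-bfgs', resolved from the default param map
```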

6. When calling model2.transform(df), this code runs underneath:
```
    def _transfer_params_to_java(self):
        """
        Transforms the embedded params to the companion Java object.
        """
        ......
            if self.hasDefault(param):
                pair = self._make_java_param_pair(param, self._defaultParamMap[param])
                pair_defaults.append(pair)
        ......

```
Again, it gets the Python default solver, which is 'auto', and this causes the exception.

7. Currently, on the Scala side, for some of the algorithms we set default values in XXXParams, so both the estimator and the model get the default value. For other algorithms we only set defaults in the estimator, so the XXXModel doesn't get them. On the Python side, we never set defaults for the XXXModel. This causes the default-value inconsistency.

8. My proposed solution: set default params in XXXParams for both Scala and Python, so both the estimator and the model have the same default value on both sides (see the sketch below). I have only changed solver in this PR. If everyone is OK with the fix, I will change all the other params as well.
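With the defaults moved into the shared params mixin, the repro at the top is expected to pass end to end; a sketch reusing `df` and the temp-path setup from that repro:
```
# Expected behavior after the fix (same setup as the repro above).
mlp = MultilayerPerceptronClassifier(layers=[2, 2, 2], seed=123)
model = mlp.fit(df)
model.save(model_path)
model2 = MultilayerPerceptronClassificationModel.load(model_path)
assert model2.getSolver() == "l-bfgs"  # the model now inherits the mixin default
model2.transform(df)                   # no longer fails on solver 'auto'
```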

I hope my explanation makes sense to you folks :)

### Why are the changes needed?
Fix a bug: default solver values were inconsistent between Scala and Python, and between estimators and their fitted models.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing and new tests.

Closes #29060 from huaxingao/solver_parity.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
5 changed files with 84 additions and 28 deletions

mllib/src/main/scala/org/apache/spark/ml/regression/FactorizationMachines.scala:

```
@@ -112,6 +112,10 @@ private[ml] trait FactorizationMachinesParams extends PredictorParams
     "The solver algorithm for optimization. Supported options: " +
       s"${supportedSolvers.mkString(", ")}. (Default adamW)",
     ParamValidators.inArray[String](supportedSolvers))
+
+  setDefault(factorSize -> 8, fitIntercept -> true, fitLinear -> true, regParam -> 0.0,
+    miniBatchFraction -> 1.0, initStd -> 0.01, maxIter -> 100, stepSize -> 1.0, tol -> 1E-6,
+    solver -> AdamW)
 }
 
 private[ml] trait FactorizationMachines extends FactorizationMachinesParams {
@@ -308,7 +312,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setFactorSize(value: Int): this.type = set(factorSize, value)
-  setDefault(factorSize -> 8)
 
   /**
    * Set whether to fit intercept term.
@@ -318,7 +321,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
-  setDefault(fitIntercept -> true)
 
   /**
    * Set whether to fit linear term.
@@ -328,7 +330,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setFitLinear(value: Boolean): this.type = set(fitLinear, value)
-  setDefault(fitLinear -> true)
 
   /**
    * Set the L2 regularization parameter.
@@ -338,7 +339,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setRegParam(value: Double): this.type = set(regParam, value)
-  setDefault(regParam -> 0.0)
 
   /**
    * Set the mini-batch fraction parameter.
@@ -348,7 +348,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setMiniBatchFraction(value: Double): this.type = set(miniBatchFraction, value)
-  setDefault(miniBatchFraction -> 1.0)
 
   /**
    * Set the standard deviation of initial coefficients.
@@ -358,7 +357,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setInitStd(value: Double): this.type = set(initStd, value)
-  setDefault(initStd -> 0.01)
 
   /**
    * Set the maximum number of iterations.
@@ -368,7 +366,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setMaxIter(value: Int): this.type = set(maxIter, value)
-  setDefault(maxIter -> 100)
 
   /**
    * Set the initial step size for the first step (like learning rate).
@@ -378,7 +375,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setStepSize(value: Double): this.type = set(stepSize, value)
-  setDefault(stepSize -> 1.0)
 
   /**
    * Set the convergence tolerance of iterations.
@@ -388,7 +384,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setTol(value: Double): this.type = set(tol, value)
-  setDefault(tol -> 1E-6)
 
   /**
    * Set the solver algorithm used for optimization.
@@ -399,7 +394,6 @@ class FMRegressor @Since("3.0.0") (
    */
   @Since("3.0.0")
   def setSolver(value: String): this.type = set(solver, value)
-  setDefault(solver -> AdamW)
 
   /**
    * Set the random seed for weight initialization.
```

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala:

```
@@ -181,6 +181,9 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
     s"${supportedSolvers.mkString(", ")}. (Default irls)",
     ParamValidators.inArray[String](supportedSolvers))
 
+  setDefault(family -> Gaussian.name, variancePower -> 0.0, maxIter -> 25, tol -> 1E-6,
+    regParam -> 0.0, solver -> IRLS)
+
   @Since("2.0.0")
   override def validateAndTransformSchema(
       schema: StructType,
@@ -257,7 +260,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    */
   @Since("2.0.0")
   def setFamily(value: String): this.type = set(family, value)
-  setDefault(family -> Gaussian.name)
 
   /**
    * Sets the value of param [[variancePower]].
@@ -268,7 +270,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    */
   @Since("2.2.0")
   def setVariancePower(value: Double): this.type = set(variancePower, value)
-  setDefault(variancePower -> 0.0)
 
   /**
    * Sets the value of param [[linkPower]].
@@ -305,7 +306,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    */
   @Since("2.0.0")
   def setMaxIter(value: Int): this.type = set(maxIter, value)
-  setDefault(maxIter -> 25)
 
   /**
    * Sets the convergence tolerance of iterations.
@@ -316,7 +316,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    */
   @Since("2.0.0")
   def setTol(value: Double): this.type = set(tol, value)
-  setDefault(tol -> 1E-6)
 
   /**
    * Sets the regularization parameter for L2 regularization.
@@ -332,7 +331,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    */
   @Since("2.0.0")
   def setRegParam(value: Double): this.type = set(regParam, value)
-  setDefault(regParam -> 0.0)
 
   /**
    * Sets the value of param [[weightCol]].
@@ -364,7 +362,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
    */
   @Since("2.0.0")
   def setSolver(value: String): this.type = set(solver, value)
-  setDefault(solver -> IRLS)
 
   /**
    * Sets the link prediction (linear predictor) column name.
```

python/pyspark/ml/classification.py:

```
@@ -2421,6 +2421,10 @@ class _MultilayerPerceptronParams(_ProbabilisticClassifierParams, HasSeed, HasMa
     initialWeights = Param(Params._dummy(), "initialWeights", "The initial weights of the model.",
                            typeConverter=TypeConverters.toVector)
 
+    def __init__(self):
+        super(_MultilayerPerceptronParams, self).__init__()
+        self._setDefault(maxIter=100, tol=1E-6, blockSize=128, stepSize=0.03, solver="l-bfgs")
+
     @since("1.6.0")
     def getLayers(self):
         """
@@ -2524,7 +2528,6 @@ class MultilayerPerceptronClassifier(_JavaProbabilisticClassifier, _MultilayerPe
         super(MultilayerPerceptronClassifier, self).__init__()
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid)
-        self._setDefault(maxIter=100, tol=1E-6, blockSize=128, stepSize=0.03, solver="l-bfgs")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
@@ -3120,9 +3123,6 @@ class FMClassifier(_JavaProbabilisticClassifier, _FactorizationMachinesParams, J
         super(FMClassifier, self).__init__()
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.classification.FMClassifier", self.uid)
-        self._setDefault(factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
-                         miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
-                         tol=1e-6, solver="adamW")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
```

python/pyspark/ml/regression.py:

```
@@ -1891,6 +1891,11 @@ class _GeneralizedLinearRegressionParams(_PredictorParams, HasFitIntercept, HasM
             "or empty, we treat all instance offsets as 0.0",
             typeConverter=TypeConverters.toString)
 
+    def __init__(self):
+        super(_GeneralizedLinearRegressionParams, self).__init__()
+        self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls",
+                         variancePower=0.0, aggregationDepth=2)
+
     @since("2.0.0")
     def getFamily(self):
         """
@@ -2023,8 +2028,6 @@ class GeneralizedLinearRegression(_JavaRegressor, _GeneralizedLinearRegressionPa
         super(GeneralizedLinearRegression, self).__init__()
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
-        self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls",
-                         variancePower=0.0, aggregationDepth=2)
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
@@ -2398,6 +2401,12 @@ class _FactorizationMachinesParams(_PredictorParams, HasMaxIter, HasStepSize, Ha
     solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
                    "options: gd, adamW. (Default adamW)", typeConverter=TypeConverters.toString)
 
+    def __init__(self):
+        super(_FactorizationMachinesParams, self).__init__()
+        self._setDefault(factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
+                         miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
+                         tol=1e-6, solver="adamW")
+
     @since("3.0.0")
     def getFactorSize(self):
         """
@@ -2489,9 +2498,6 @@ class FMRegressor(_JavaRegressor, _FactorizationMachinesParams, JavaMLWritable,
         super(FMRegressor, self).__init__()
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.regression.FMRegressor", self.uid)
-        self._setDefault(factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0,
-                         miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0,
-                         tol=1e-6, solver="adamW")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
```

python/pyspark/ml/tests/test_persistence.py:

```
@@ -21,19 +21,78 @@ import tempfile
 import unittest
 
 from pyspark.ml import Transformer
-from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, OneVsRest, \
-    OneVsRestModel
+from pyspark.ml.classification import DecisionTreeClassifier, FMClassifier, \
+    FMClassificationModel, LogisticRegression, MultilayerPerceptronClassifier, \
+    MultilayerPerceptronClassificationModel, OneVsRest, OneVsRestModel
 from pyspark.ml.clustering import KMeans
 from pyspark.ml.feature import Binarizer, HashingTF, PCA
 from pyspark.ml.linalg import Vectors
 from pyspark.ml.param import Params
 from pyspark.ml.pipeline import Pipeline, PipelineModel
-from pyspark.ml.regression import DecisionTreeRegressor, LinearRegression
+from pyspark.ml.regression import DecisionTreeRegressor, GeneralizedLinearRegression, \
+    GeneralizedLinearRegressionModel, \
+    LinearRegression
 from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWriter
 from pyspark.ml.wrapper import JavaParams
 from pyspark.testing.mlutils import MockUnaryTransformer, SparkSessionTestCase
 
 
+class TestDefaultSolver(SparkSessionTestCase):
+
+    def test_multilayer_load(self):
+        df = self.spark.createDataFrame([(0.0, Vectors.dense([0.0, 0.0])),
+                                         (1.0, Vectors.dense([0.0, 1.0])),
+                                         (1.0, Vectors.dense([1.0, 0.0])),
+                                         (0.0, Vectors.dense([1.0, 1.0]))],
+                                        ["label", "features"])
+        mlp = MultilayerPerceptronClassifier(layers=[2, 2, 2], seed=123)
+        model = mlp.fit(df)
+        self.assertEqual(model.getSolver(), "l-bfgs")
+        transformed1 = model.transform(df)
+        path = tempfile.mkdtemp()
+        model_path = path + "/mlp"
+        model.save(model_path)
+        model2 = MultilayerPerceptronClassificationModel.load(model_path)
+        self.assertEqual(model2.getSolver(), "l-bfgs")
+        transformed2 = model2.transform(df)
+        self.assertEqual(transformed1.take(4), transformed2.take(4))
+
+    def test_fm_load(self):
+        df = self.spark.createDataFrame([(1.0, Vectors.dense(1.0)),
+                                         (0.0, Vectors.sparse(1, [], []))],
+                                        ["label", "features"])
+        fm = FMClassifier(factorSize=2, maxIter=50, stepSize=2.0)
+        model = fm.fit(df)
+        self.assertEqual(model.getSolver(), "adamW")
+        transformed1 = model.transform(df)
+        path = tempfile.mkdtemp()
+        model_path = path + "/fm"
+        model.save(model_path)
+        model2 = FMClassificationModel.load(model_path)
+        self.assertEqual(model2.getSolver(), "adamW")
+        transformed2 = model2.transform(df)
+        self.assertEqual(transformed1.take(2), transformed2.take(2))
+
+    def test_glr_load(self):
+        df = self.spark.createDataFrame([(1.0, Vectors.dense(0.0, 0.0)),
+                                         (1.0, Vectors.dense(1.0, 2.0)),
+                                         (2.0, Vectors.dense(0.0, 0.0)),
+                                         (2.0, Vectors.dense(1.0, 1.0))],
+                                        ["label", "features"])
+        glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p")
+        model = glr.fit(df)
+        self.assertEqual(model.getSolver(), "irls")
+        transformed1 = model.transform(df)
+        path = tempfile.mkdtemp()
+        model_path = path + "/glr"
+        model.save(model_path)
+        model2 = GeneralizedLinearRegressionModel.load(model_path)
+        self.assertEqual(model2.getSolver(), "irls")
+        transformed2 = model2.transform(df)
+        self.assertEqual(transformed1.take(4), transformed2.take(4))
+
+
 class PersistenceTest(SparkSessionTestCase):
 
     def test_linear_regression(self):
```