From c94b34ebbf4c6ce353c899c571beb34e8db98917 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 19 May 2016 23:35:20 -0700 Subject: [PATCH] [SPARK-15339][ML] ML 2.0 QA: Scala APIs and code audit for regression ## What changes were proposed in this pull request? * ```GeneralizedLinearRegression``` API docs enhancement. * The default value of ```GeneralizedLinearRegression``` ```linkPredictionCol``` is not set rather than empty. This will consistent with other similar params such as ```weightCol``` * Make some methods more private. * Fix a minor bug of LinearRegression. * Fix some other issues. ## How was this patch tested? Existing tests. Author: Yanbo Liang Closes #13129 from yanboliang/spark-15339. --- .../ml/regression/AFTSurvivalRegression.scala | 4 +- .../GeneralizedLinearRegression.scala | 74 ++++++++++--------- .../ml/regression/IsotonicRegression.scala | 4 +- .../ml/regression/LinearRegression.scala | 8 +- .../ml/regression/LinearRegressionSuite.scala | 15 +++- 5 files changed, 58 insertions(+), 47 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index cc16c2f038..e63eb71080 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -89,8 +89,8 @@ private[regression] trait AFTSurvivalRegressionParams extends Params def getQuantilesCol: String = $(quantilesCol) /** Checks whether the input has quantiles column name. */ - protected[regression] def hasQuantilesCol: Boolean = { - isDefined(quantilesCol) && $(quantilesCol) != "" + private[regression] def hasQuantilesCol: Boolean = { + isDefined(quantilesCol) && $(quantilesCol).nonEmpty } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index e8474d035e..adbdd345e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -43,6 +43,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam with HasFitIntercept with HasMaxIter with HasTol with HasRegParam with HasWeightCol with HasSolver with Logging { + import GeneralizedLinearRegression._ + /** * Param for the name of family which is a description of the error distribution * to be used in the model. @@ -54,8 +56,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam @Since("2.0.0") final val family: Param[String] = new Param(this, "family", "The name of family which is a description of the error distribution to be used in the " + - "model. Supported options: gaussian(default), binomial, poisson and gamma.", - ParamValidators.inArray[String](GeneralizedLinearRegression.supportedFamilyNames.toArray)) + s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.", + ParamValidators.inArray[String](supportedFamilyNames.toArray)) /** @group getParam */ @Since("2.0.0") @@ -71,9 +73,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam @Since("2.0.0") final val link: Param[String] = new Param(this, "link", "The name of link function " + "which provides the relationship between the linear predictor and the mean of the " + - "distribution function. Supported options: identity, log, inverse, logit, probit, " + - "cloglog and sqrt.", - ParamValidators.inArray[String](GeneralizedLinearRegression.supportedLinkNames.toArray)) + s"distribution function. Supported options: ${supportedLinkNames.mkString(", ")}", + ParamValidators.inArray[String](supportedLinkNames.toArray)) /** @group getParam */ @Since("2.0.0") @@ -81,19 +82,23 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam /** * Param for link prediction (linear predictor) column name. - * Default is empty, which means we do not output link prediction. + * Default is not set, which means we do not output link prediction. * * @group param */ @Since("2.0.0") final val linkPredictionCol: Param[String] = new Param[String](this, "linkPredictionCol", "link prediction (linear predictor) column name") - setDefault(linkPredictionCol, "") /** @group getParam */ @Since("2.0.0") def getLinkPredictionCol: String = $(linkPredictionCol) + /** Checks whether we should output link prediction. */ + private[regression] def hasLinkPredictionCol: Boolean = { + isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty + } + import GeneralizedLinearRegression._ @Since("2.0.0") @@ -107,7 +112,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam s"with ${$(family)} family does not support ${$(link)} link function.") } val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) - if ($(linkPredictionCol).nonEmpty) { + if (hasLinkPredictionCol) { SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType) } else { newSchema @@ -205,7 +210,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val /** * Sets the value of param [[weightCol]]. * If this is not set or empty, we treat all instance weights as 1.0. - * Default is empty, so all instances have weight one. + * Default is not set, so all instances have weight one. * * @group setParam */ @@ -214,7 +219,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val /** * Sets the solver algorithm used for optimization. - * Currently only support "irls" which is also the default solver. + * Currently only supports "irls" which is also the default solver. * * @group setParam */ @@ -239,10 +244,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val } val familyAndLink = new FamilyAndLink(familyObj, linkObj) - val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd - .map { case Row(features: Vector) => - features.size - }.first() + val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) { val msg = "Currently, GeneralizedLinearRegression only supports number of features" + s" <= ${WeightedLeastSquares.MAX_NUM_FEATURES}. Found $numFeatures in the input dataset." @@ -294,7 +296,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def load(path: String): GeneralizedLinearRegression = super.load(path) /** Set of family and link pairs that GeneralizedLinearRegression supports. */ - private[ml] lazy val supportedFamilyAndLinkPairs = Set( + private[regression] lazy val supportedFamilyAndLinkPairs = Set( Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse, Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog, Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt, @@ -302,17 +304,17 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine ) /** Set of family names that GeneralizedLinearRegression supports. */ - private[ml] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name) + private[regression] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name) /** Set of link names that GeneralizedLinearRegression supports. */ - private[ml] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name) + private[regression] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name) - private[ml] val epsilon: Double = 1E-16 + private[regression] val epsilon: Double = 1E-16 /** * Wrapper of family and link combination used in the model. */ - private[ml] class FamilyAndLink(val family: Family, val link: Link) extends Serializable { + private[regression] class FamilyAndLink(val family: Family, val link: Link) extends Serializable { /** Linear predictor based on given mu. */ def predict(mu: Double): Double = link.link(family.project(mu)) @@ -359,7 +361,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * * @param name the name of the family. */ - private[ml] abstract class Family(val name: String) extends Serializable { + private[regression] abstract class Family(val name: String) extends Serializable { /** The default link instance of this family. */ val defaultLink: Link @@ -391,7 +393,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine def project(mu: Double): Double = mu } - private[ml] object Family { + private[regression] object Family { /** * Gets the [[Family]] object from its name. @@ -412,7 +414,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * Gaussian exponential family distribution. * The default link for the Gaussian family is the identity link. */ - private[ml] object Gaussian extends Family("gaussian") { + private[regression] object Gaussian extends Family("gaussian") { val defaultLink: Link = Identity @@ -448,7 +450,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * Binomial exponential family distribution. * The default link for the Binomial family is the logit link. */ - private[ml] object Binomial extends Family("binomial") { + private[regression] object Binomial extends Family("binomial") { val defaultLink: Link = Logit @@ -492,7 +494,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * Poisson exponential family distribution. * The default link for the Poisson family is the log link. */ - private[ml] object Poisson extends Family("poisson") { + private[regression] object Poisson extends Family("poisson") { val defaultLink: Link = Log @@ -533,7 +535,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * Gamma exponential family distribution. * The default link for the Gamma family is the inverse link. */ - private[ml] object Gamma extends Family("gamma") { + private[regression] object Gamma extends Family("gamma") { val defaultLink: Link = Inverse @@ -578,7 +580,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * * @param name the name of link function. */ - private[ml] abstract class Link(val name: String) extends Serializable { + private[regression] abstract class Link(val name: String) extends Serializable { /** The link function. */ def link(mu: Double): Double @@ -590,7 +592,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine def unlink(eta: Double): Double } - private[ml] object Link { + private[regression] object Link { /** * Gets the [[Link]] object from its name. @@ -611,7 +613,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine } } - private[ml] object Identity extends Link("identity") { + private[regression] object Identity extends Link("identity") { override def link(mu: Double): Double = mu @@ -620,7 +622,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def unlink(eta: Double): Double = eta } - private[ml] object Logit extends Link("logit") { + private[regression] object Logit extends Link("logit") { override def link(mu: Double): Double = math.log(mu / (1.0 - mu)) @@ -629,7 +631,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta)) } - private[ml] object Log extends Link("log") { + private[regression] object Log extends Link("log") { override def link(mu: Double): Double = math.log(mu) @@ -638,7 +640,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def unlink(eta: Double): Double = math.exp(eta) } - private[ml] object Inverse extends Link("inverse") { + private[regression] object Inverse extends Link("inverse") { override def link(mu: Double): Double = 1.0 / mu @@ -647,7 +649,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def unlink(eta: Double): Double = 1.0 / eta } - private[ml] object Probit extends Link("probit") { + private[regression] object Probit extends Link("probit") { override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).icdf(mu) @@ -658,7 +660,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def unlink(eta: Double): Double = dist.Gaussian(0.0, 1.0).cdf(eta) } - private[ml] object CLogLog extends Link("cloglog") { + private[regression] object CLogLog extends Link("cloglog") { override def link(mu: Double): Double = math.log(-1.0 * math.log(1 - mu)) @@ -667,7 +669,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def unlink(eta: Double): Double = 1.0 - math.exp(-1.0 * math.exp(eta)) } - private[ml] object Sqrt extends Link("sqrt") { + private[regression] object Sqrt extends Link("sqrt") { override def link(mu: Double): Double = math.sqrt(mu) @@ -732,7 +734,7 @@ class GeneralizedLinearRegressionModel private[ml] ( if ($(predictionCol).nonEmpty) { output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - if ($(linkPredictionCol).nonEmpty) { + if (hasLinkPredictionCol) { output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)))) } output.toDF() @@ -860,7 +862,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( */ @Since("2.0.0") val predictionCol: String = { - if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol != "") { + if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol.nonEmpty) { origModel.getPredictionCol } else { "prediction_" + java.util.UUID.randomUUID.toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index ba0f59e89b..d16e8e3f6b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -69,8 +69,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures setDefault(isotonic -> true, featureIndex -> 0) /** Checks whether the input has weight column. */ - protected[ml] def hasWeightCol: Boolean = { - isDefined(weightCol) && $(weightCol) != "" + private[regression] def hasWeightCol: Boolean = { + isDefined(weightCol) && $(weightCol).nonEmpty } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index a702f02c91..ff1038cbf1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -161,9 +161,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String override protected def train(dataset: Dataset[_]): LinearRegressionModel = { // Extract the number of features before deciding optimization solver. - val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map { - case Row(features: Vector) => features.size - }.first() + val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && @@ -242,7 +240,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val coefficients = Vectors.sparse(numFeatures, Seq()) val intercept = yMean - val model = new LinearRegressionModel(uid, coefficients, intercept) + val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept)) // Handle possible missing or invalid prediction columns val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() @@ -254,7 +252,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model, Array(0D), Array(0D)) - return copyValues(model.setSummary(trainingSummary)) + return model.setSummary(trainingSummary) } else { require($(regParam) == 0.0, "The standard deviation of the label is zero. " + "Model cannot be regularized.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 332d331a47..265f2f45c4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -610,20 +610,31 @@ class LinearRegressionSuite val model1 = new LinearRegression() .setFitIntercept(fitIntercept) .setWeightCol("weight") + .setPredictionCol("myPrediction") .setSolver(solver) .fit(datasetWithWeightConstantLabel) val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0), model1.coefficients(1)) assert(actual1 ~== expected(idx) absTol 1e-4) + // Schema of summary.predictions should be a superset of the input dataset + assert((datasetWithWeightConstantLabel.schema.fieldNames.toSet + model1.getPredictionCol) + .subsetOf(model1.summary.predictions.schema.fieldNames.toSet)) + val model2 = new LinearRegression() .setFitIntercept(fitIntercept) .setWeightCol("weight") + .setPredictionCol("myPrediction") .setSolver(solver) .fit(datasetWithWeightZeroLabel) val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0), model2.coefficients(1)) assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4) + + // Schema of summary.predictions should be a superset of the input dataset + assert((datasetWithWeightZeroLabel.schema.fieldNames.toSet + model2.getPredictionCol) + .subsetOf(model2.summary.predictions.schema.fieldNames.toSet)) + idx += 1 } } @@ -672,7 +683,7 @@ class LinearRegressionSuite test("linear regression model training summary") { Seq("auto", "l-bfgs", "normal").foreach { solver => - val trainer = new LinearRegression().setSolver(solver) + val trainer = new LinearRegression().setSolver(solver).setPredictionCol("myPrediction") val model = trainer.fit(datasetWithDenseFeature) val trainerNoPredictionCol = trainer.setPredictionCol("") val modelNoPredictionCol = trainerNoPredictionCol.fit(datasetWithDenseFeature) @@ -682,7 +693,7 @@ class LinearRegressionSuite assert(modelNoPredictionCol.hasSummary) // Schema should be a superset of the input dataset - assert((datasetWithDenseFeature.schema.fieldNames.toSet + "prediction").subsetOf( + assert((datasetWithDenseFeature.schema.fieldNames.toSet + model.getPredictionCol).subsetOf( model.summary.predictions.schema.fieldNames.toSet)) // Validate that we re-insert a prediction column for evaluation val modelNoPredictionColFieldNames