[SPARK-15339][ML] ML 2.0 QA: Scala APIs and code audit for regression

## What changes were proposed in this pull request?
* ```GeneralizedLinearRegression``` API docs enhancement.
* The default value of ```GeneralizedLinearRegression``` ```linkPredictionCol``` is not set rather than empty. This will consistent with other similar params such as ```weightCol```
* Make some methods more private.
* Fix a minor bug of LinearRegression.
* Fix some other issues.

## How was this patch tested?
Existing tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #13129 from yanboliang/spark-15339.
This commit is contained in:
Yanbo Liang 2016-05-19 23:35:20 -07:00 committed by Xiangrui Meng
parent 5e203505f1
commit c94b34ebbf
5 changed files with 58 additions and 47 deletions

View file

@ -89,8 +89,8 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
def getQuantilesCol: String = $(quantilesCol)
/** Checks whether the input has quantiles column name. */
protected[regression] def hasQuantilesCol: Boolean = {
isDefined(quantilesCol) && $(quantilesCol) != ""
private[regression] def hasQuantilesCol: Boolean = {
isDefined(quantilesCol) && $(quantilesCol).nonEmpty
}
/**

View file

@ -43,6 +43,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
with HasFitIntercept with HasMaxIter with HasTol with HasRegParam with HasWeightCol
with HasSolver with Logging {
import GeneralizedLinearRegression._
/**
* Param for the name of family which is a description of the error distribution
* to be used in the model.
@ -54,8 +56,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
@Since("2.0.0")
final val family: Param[String] = new Param(this, "family",
"The name of family which is a description of the error distribution to be used in the " +
"model. Supported options: gaussian(default), binomial, poisson and gamma.",
ParamValidators.inArray[String](GeneralizedLinearRegression.supportedFamilyNames.toArray))
s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.",
ParamValidators.inArray[String](supportedFamilyNames.toArray))
/** @group getParam */
@Since("2.0.0")
@ -71,9 +73,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
@Since("2.0.0")
final val link: Param[String] = new Param(this, "link", "The name of link function " +
"which provides the relationship between the linear predictor and the mean of the " +
"distribution function. Supported options: identity, log, inverse, logit, probit, " +
"cloglog and sqrt.",
ParamValidators.inArray[String](GeneralizedLinearRegression.supportedLinkNames.toArray))
s"distribution function. Supported options: ${supportedLinkNames.mkString(", ")}",
ParamValidators.inArray[String](supportedLinkNames.toArray))
/** @group getParam */
@Since("2.0.0")
@ -81,19 +82,23 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
/**
* Param for link prediction (linear predictor) column name.
* Default is empty, which means we do not output link prediction.
* Default is not set, which means we do not output link prediction.
*
* @group param
*/
@Since("2.0.0")
final val linkPredictionCol: Param[String] = new Param[String](this, "linkPredictionCol",
"link prediction (linear predictor) column name")
setDefault(linkPredictionCol, "")
/** @group getParam */
@Since("2.0.0")
def getLinkPredictionCol: String = $(linkPredictionCol)
/** Checks whether we should output link prediction. */
private[regression] def hasLinkPredictionCol: Boolean = {
isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty
}
import GeneralizedLinearRegression._
@Since("2.0.0")
@ -107,7 +112,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
s"with ${$(family)} family does not support ${$(link)} link function.")
}
val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
if ($(linkPredictionCol).nonEmpty) {
if (hasLinkPredictionCol) {
SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType)
} else {
newSchema
@ -205,7 +210,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
/**
* Sets the value of param [[weightCol]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is empty, so all instances have weight one.
* Default is not set, so all instances have weight one.
*
* @group setParam
*/
@ -214,7 +219,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
/**
* Sets the solver algorithm used for optimization.
* Currently only support "irls" which is also the default solver.
* Currently only supports "irls" which is also the default solver.
*
* @group setParam
*/
@ -239,10 +244,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
}
val familyAndLink = new FamilyAndLink(familyObj, linkObj)
val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd
.map { case Row(features: Vector) =>
features.size
}.first()
val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) {
val msg = "Currently, GeneralizedLinearRegression only supports number of features" +
s" <= ${WeightedLeastSquares.MAX_NUM_FEATURES}. Found $numFeatures in the input dataset."
@ -294,7 +296,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def load(path: String): GeneralizedLinearRegression = super.load(path)
/** Set of family and link pairs that GeneralizedLinearRegression supports. */
private[ml] lazy val supportedFamilyAndLinkPairs = Set(
private[regression] lazy val supportedFamilyAndLinkPairs = Set(
Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse,
Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog,
Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt,
@ -302,17 +304,17 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
)
/** Set of family names that GeneralizedLinearRegression supports. */
private[ml] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
private[regression] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name)
/** Set of link names that GeneralizedLinearRegression supports. */
private[ml] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
private[regression] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name)
private[ml] val epsilon: Double = 1E-16
private[regression] val epsilon: Double = 1E-16
/**
* Wrapper of family and link combination used in the model.
*/
private[ml] class FamilyAndLink(val family: Family, val link: Link) extends Serializable {
private[regression] class FamilyAndLink(val family: Family, val link: Link) extends Serializable {
/** Linear predictor based on given mu. */
def predict(mu: Double): Double = link.link(family.project(mu))
@ -359,7 +361,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
*
* @param name the name of the family.
*/
private[ml] abstract class Family(val name: String) extends Serializable {
private[regression] abstract class Family(val name: String) extends Serializable {
/** The default link instance of this family. */
val defaultLink: Link
@ -391,7 +393,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
def project(mu: Double): Double = mu
}
private[ml] object Family {
private[regression] object Family {
/**
* Gets the [[Family]] object from its name.
@ -412,7 +414,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
* Gaussian exponential family distribution.
* The default link for the Gaussian family is the identity link.
*/
private[ml] object Gaussian extends Family("gaussian") {
private[regression] object Gaussian extends Family("gaussian") {
val defaultLink: Link = Identity
@ -448,7 +450,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
* Binomial exponential family distribution.
* The default link for the Binomial family is the logit link.
*/
private[ml] object Binomial extends Family("binomial") {
private[regression] object Binomial extends Family("binomial") {
val defaultLink: Link = Logit
@ -492,7 +494,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
* Poisson exponential family distribution.
* The default link for the Poisson family is the log link.
*/
private[ml] object Poisson extends Family("poisson") {
private[regression] object Poisson extends Family("poisson") {
val defaultLink: Link = Log
@ -533,7 +535,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
* Gamma exponential family distribution.
* The default link for the Gamma family is the inverse link.
*/
private[ml] object Gamma extends Family("gamma") {
private[regression] object Gamma extends Family("gamma") {
val defaultLink: Link = Inverse
@ -578,7 +580,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
*
* @param name the name of link function.
*/
private[ml] abstract class Link(val name: String) extends Serializable {
private[regression] abstract class Link(val name: String) extends Serializable {
/** The link function. */
def link(mu: Double): Double
@ -590,7 +592,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
def unlink(eta: Double): Double
}
private[ml] object Link {
private[regression] object Link {
/**
* Gets the [[Link]] object from its name.
@ -611,7 +613,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
}
}
private[ml] object Identity extends Link("identity") {
private[regression] object Identity extends Link("identity") {
override def link(mu: Double): Double = mu
@ -620,7 +622,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def unlink(eta: Double): Double = eta
}
private[ml] object Logit extends Link("logit") {
private[regression] object Logit extends Link("logit") {
override def link(mu: Double): Double = math.log(mu / (1.0 - mu))
@ -629,7 +631,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta))
}
private[ml] object Log extends Link("log") {
private[regression] object Log extends Link("log") {
override def link(mu: Double): Double = math.log(mu)
@ -638,7 +640,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def unlink(eta: Double): Double = math.exp(eta)
}
private[ml] object Inverse extends Link("inverse") {
private[regression] object Inverse extends Link("inverse") {
override def link(mu: Double): Double = 1.0 / mu
@ -647,7 +649,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def unlink(eta: Double): Double = 1.0 / eta
}
private[ml] object Probit extends Link("probit") {
private[regression] object Probit extends Link("probit") {
override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).icdf(mu)
@ -658,7 +660,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def unlink(eta: Double): Double = dist.Gaussian(0.0, 1.0).cdf(eta)
}
private[ml] object CLogLog extends Link("cloglog") {
private[regression] object CLogLog extends Link("cloglog") {
override def link(mu: Double): Double = math.log(-1.0 * math.log(1 - mu))
@ -667,7 +669,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
override def unlink(eta: Double): Double = 1.0 - math.exp(-1.0 * math.exp(eta))
}
private[ml] object Sqrt extends Link("sqrt") {
private[regression] object Sqrt extends Link("sqrt") {
override def link(mu: Double): Double = math.sqrt(mu)
@ -732,7 +734,7 @@ class GeneralizedLinearRegressionModel private[ml] (
if ($(predictionCol).nonEmpty) {
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
if ($(linkPredictionCol).nonEmpty) {
if (hasLinkPredictionCol) {
output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol))))
}
output.toDF()
@ -860,7 +862,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
*/
@Since("2.0.0")
val predictionCol: String = {
if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol != "") {
if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol.nonEmpty) {
origModel.getPredictionCol
} else {
"prediction_" + java.util.UUID.randomUUID.toString

View file

@ -69,8 +69,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
setDefault(isotonic -> true, featureIndex -> 0)
/** Checks whether the input has weight column. */
protected[ml] def hasWeightCol: Boolean = {
isDefined(weightCol) && $(weightCol) != ""
private[regression] def hasWeightCol: Boolean = {
isDefined(weightCol) && $(weightCol).nonEmpty
}
/**

View file

@ -161,9 +161,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
// Extract the number of features before deciding optimization solver.
val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map {
case Row(features: Vector) => features.size
}.first()
val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
if (($(solver) == "auto" && $(elasticNetParam) == 0.0 &&
@ -242,7 +240,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
val coefficients = Vectors.sparse(numFeatures, Seq())
val intercept = yMean
val model = new LinearRegressionModel(uid, coefficients, intercept)
val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept))
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
@ -254,7 +252,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
model,
Array(0D),
Array(0D))
return copyValues(model.setSummary(trainingSummary))
return model.setSummary(trainingSummary)
} else {
require($(regParam) == 0.0, "The standard deviation of the label is zero. " +
"Model cannot be regularized.")

View file

@ -610,20 +610,31 @@ class LinearRegressionSuite
val model1 = new LinearRegression()
.setFitIntercept(fitIntercept)
.setWeightCol("weight")
.setPredictionCol("myPrediction")
.setSolver(solver)
.fit(datasetWithWeightConstantLabel)
val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0),
model1.coefficients(1))
assert(actual1 ~== expected(idx) absTol 1e-4)
// Schema of summary.predictions should be a superset of the input dataset
assert((datasetWithWeightConstantLabel.schema.fieldNames.toSet + model1.getPredictionCol)
.subsetOf(model1.summary.predictions.schema.fieldNames.toSet))
val model2 = new LinearRegression()
.setFitIntercept(fitIntercept)
.setWeightCol("weight")
.setPredictionCol("myPrediction")
.setSolver(solver)
.fit(datasetWithWeightZeroLabel)
val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0),
model2.coefficients(1))
assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4)
// Schema of summary.predictions should be a superset of the input dataset
assert((datasetWithWeightZeroLabel.schema.fieldNames.toSet + model2.getPredictionCol)
.subsetOf(model2.summary.predictions.schema.fieldNames.toSet))
idx += 1
}
}
@ -672,7 +683,7 @@ class LinearRegressionSuite
test("linear regression model training summary") {
Seq("auto", "l-bfgs", "normal").foreach { solver =>
val trainer = new LinearRegression().setSolver(solver)
val trainer = new LinearRegression().setSolver(solver).setPredictionCol("myPrediction")
val model = trainer.fit(datasetWithDenseFeature)
val trainerNoPredictionCol = trainer.setPredictionCol("")
val modelNoPredictionCol = trainerNoPredictionCol.fit(datasetWithDenseFeature)
@ -682,7 +693,7 @@ class LinearRegressionSuite
assert(modelNoPredictionCol.hasSummary)
// Schema should be a superset of the input dataset
assert((datasetWithDenseFeature.schema.fieldNames.toSet + "prediction").subsetOf(
assert((datasetWithDenseFeature.schema.fieldNames.toSet + model.getPredictionCol).subsetOf(
model.summary.predictions.schema.fieldNames.toSet))
// Validate that we re-insert a prediction column for evaluation
val modelNoPredictionColFieldNames